diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 40b265b539..e7919043ce 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -44,6 +44,7 @@ def moe_sorting( device = topk_ids.device M, topk = topk_ids.shape max_num_tokens_padded = topk_ids.numel() + num_experts * block_size - topk + max_num_m_blocks = int((max_num_tokens_padded + block_size - 1) // block_size) sorted_ids = torch.empty((max_num_tokens_padded,), dtype=dtypes.i32, device=device) sorted_weights = torch.empty( @@ -104,6 +105,11 @@ def fused_moe( num_local_tokens: Optional[torch.tensor] = None, moe_sorting_dispatch_policy=0, dtype=None, + # following for cktile support + hidden_pad=0, + intermediate_pad=0, + bias1=None, + bias2=None, ): if not block_size_M: block_size_M = -1 @@ -125,6 +131,10 @@ def fused_moe( num_local_tokens=num_local_tokens, moe_sorting_dispatch_policy=moe_sorting_dispatch_policy, dtype=dtype, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -178,6 +188,10 @@ def fused_moe_( num_local_tokens: Optional[torch.Tensor] = None, moe_sorting_dispatch_policy: bool = 0, dtype: Optional[torch.dtype] = None, + hidden_pad: int = 0, + intermediate_pad: int = 0, + bias1: Optional[torch.Tensor] = None, + bias2: Optional[torch.Tensor] = None, ) -> torch.Tensor: # We do such convert since custom_op schema restriction on block_size_M, and Enum type activation = ActivationType(activation) @@ -220,6 +234,10 @@ def fused_moe_( isG1U1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ) block_size_M = metadata.block_m if block_size_M is None else block_size_M @@ -252,6 +270,8 @@ def fused_moe_( moe_buf, isG1U1, block_size_M, + # activation=activation, + # quant_type=quant_type, q_dtype_a=q_dtype_a, q_dtype_w=q_dtype_w, w1_scale=w1_scale, @@ -283,6 +303,11 @@ def fused_moe_( a1_scale=a1_scale, a2_scale=a2_scale, num_local_tokens=num_local_tokens, + # following for cktile support + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -491,6 +516,10 @@ def get_2stage_cfgs( use_g1u1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ): def get_cfg_2stages(tune_file): import pandas as pd @@ -545,7 +574,6 @@ def MainFunc(): f.write( "token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1" ) - q_dtype_ws = q_dtype_w if q_dtype_w != torch.uint32 else "torch.int4" f.write( f"\n{token},{model_dim},{inter_dim},{expert},{topk},{activation},{dtype},{q_dtype_a},{q_dtype_ws},{q_type},{int(use_g1u1)},{int(doweight_stage1)}" @@ -624,6 +652,28 @@ def FinalFunc(): ksplit, run_1stage, ) + if ( + dtype in [dtypes.bf16, dtypes.fp16] + and q_type == QuantType.per_1x32 + and activation == ActivationType.Swiglu + ): + return MOEMetadata( + functools.partial( + cktile_moe_stage1, + n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1), + k_pad_zeros=hidden_pad // 128 * 128, + bias1=bias1, + ), + functools.partial( + cktile_moe_stage2, + n_pad_zeros=hidden_pad // 64 * 64, + k_pad_zeros=intermediate_pad // 128 * 128, + bias2=bias2, + ), + 16 if token < 2048 else 32, + ksplit, + False, + ) if ( "ck2stages" in kernelName1 or (q_type == QuantType.per_1x128 and doweight_stage1) @@ -701,6 +751,11 @@ def fused_moe_2stages( a1_scale=None, # [expert(local_expert:EP), 1, model_dim] a2_scale=None, # [expert(local_expert:EP), 1, inter_dim] num_local_tokens: Optional[torch.tensor] = None, + # following for cktile support + hidden_pad=0, + intermediate_pad=0, + bias1=None, + bias2=None, ): quant_func = get_quant(quant_type) @@ -708,7 +763,6 @@ def fused_moe_2stages( E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) dtype = moe_out.dtype device = hidden_states.device - metadata = get_2stage_cfgs( get_padded_M(token_num), # consider token_num > 1024 as prefill model_dim, @@ -722,9 +776,20 @@ def fused_moe_2stages( isG1U1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ) - - if quant_type == QuantType.per_1x32: + if ( + quant_type == QuantType.per_1x32 + and dtype in [dtypes.bf16, dtypes.fp16] + and w1.dtype == dtypes.fp4x2 + and activation == ActivationType.Swiglu + ): + a1 = hidden_states.to(dtype) + a1_scale = None + elif quant_type == QuantType.per_1x32: a1, a1_scale = quant_func( hidden_states, scale=a1_scale, @@ -781,7 +846,14 @@ def fused_moe_2stages( sorted_weights=sorted_weights if doweight_stage1 else None, ) - if quant_type == QuantType.per_1x32: + if ( + quant_type == QuantType.per_1x32 + and dtype in [dtypes.bf16, dtypes.fp16] + and w1.dtype == dtypes.fp4x2 + and activation == ActivationType.Swiglu + ): + a2_scale = None + elif quant_type == QuantType.per_1x32: a2 = a2.view(-1, inter_dim) a2, a2_scale = quant_func( a2, @@ -972,6 +1044,16 @@ def torch_moe( return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype) +# temp workaround for swiglu +def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0): + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + + def torch_moe_stage1( hidden_states, w1, # E, inter_dim*2, model_dim @@ -984,6 +1066,7 @@ def torch_moe_stage1( # following for quant a1_scale=None, # [token, 1] w1_scale=None, # [expert, inter_dim, 1] + w1_bias=None, # [expert, inter_dim, 1] doweight=False, ): quant_type = quant_remap.get(quant_type, quant_type) @@ -995,10 +1078,14 @@ def torch_moe_stage1( if quant_type == QuantType.per_1x32: from aiter.utility import fp4_utils - hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) w1 = fp4_utils.mxfp4_to_f32(w1) w1_scale = fp4_utils.e8m0_to_f32(w1_scale) - a1_scale = fp4_utils.e8m0_to_f32(a1_scale) + if a1_scale is not None: # skip a16w4 + hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) + a1_scale = fp4_utils.e8m0_to_f32(a1_scale) + else: # a16w4 + hidden_states = hidden_states.to(ctype) + else: hidden_states = hidden_states.to(ctype) w1 = w1.to(ctype) @@ -1006,8 +1093,8 @@ def torch_moe_stage1( if quant_type in [QuantType.per_Token, QuantType.per_Tensor]: w1 = w1 * w1_scale.view(w1_scale.shape[0], -1, 1) hidden_states = hidden_states * a1_scale - # per_1x128 - elif quant_type == QuantType.per_1x128: + # per_128x128 + elif quant_type in [QuantType.per_128x128, QuantType.per_1x128]: w1_shape = w1.shape w1 = w1.view( w1.shape[0], w1.shape[1] // 128, 128, w1.shape[2] // 128, 128 @@ -1031,9 +1118,12 @@ def torch_moe_stage1( w1 = w1.view(w1_shape) a1_shape = hidden_states.shape - a1_scale = a1_scale[: a1_shape[0]] hidden_states = hidden_states.view(a1_shape[0], a1_shape[1] // 32, 32) - hidden_states = hidden_states * a1_scale.view(a1_shape[0], a1_shape[1] // 32, 1) + if a1_scale is not None: + a1_scale = a1_scale[: a1_shape[0]] + hidden_states = hidden_states * a1_scale.view( + a1_shape[0], a1_shape[1] // 32, 1 + ) hidden_states = hidden_states.view(a1_shape) else: assert False, f"Unsupported quant_type: {quant_type}" @@ -1053,11 +1143,17 @@ def torch_moe_stage1( if doweight: act_input = act_input * topk_weight[mask].view(-1, 1) out[mask] = act_input + if w1_bias is not None: + out[mask] = out[mask] + w1_bias[E_id].view(1, -1) use_g1u1 = w1.shape[1] == (2 * inter_dim) + use_swiglu = (a1_scale is None) and (quant_type == QuantType.per_1x32) torch_act = aiter.get_torch_act(activation) if use_g1u1: gate, up = out.split([inter_dim, inter_dim], dim=-1) - out = torch_act(gate) * up + if use_swiglu: + out = swiglu(gate, up) + else: + out = torch_act(gate) * up else: out = torch_act(out) return out.to(dtype) @@ -1073,18 +1169,21 @@ def torch_moe_stage2( quant_type=QuantType.No, w2_scale=None, # [1] a2_scale=None, # [expert]]' + w2_bias=None, doweight=True, ): - quant_type = quant_remap.get(quant_type, quant_type) ctype = dtypes.fp32 # compute type E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) if quant_type == QuantType.per_1x32: from aiter.utility import fp4_utils - hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) w2 = fp4_utils.mxfp4_to_f32(w2) w2_scale = fp4_utils.e8m0_to_f32(w2_scale) - a2_scale = fp4_utils.e8m0_to_f32(a2_scale) + if a2_scale is not None: + hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) + a2_scale = fp4_utils.e8m0_to_f32(a2_scale) + else: # a16w4 + hidden_states = hidden_states.to(ctype) else: hidden_states = hidden_states.to(ctype) w2 = w2.to(ctype) @@ -1095,7 +1194,7 @@ def torch_moe_stage2( if quant_type in [QuantType.per_Token, QuantType.per_Tensor]: hidden_states = hidden_states * a2_scale.view(a2_scale.shape[0], -1, 1) w2 = w2 * w2_scale.view(w2_scale.shape[0], -1, 1) - elif quant_type == QuantType.per_1x128: + elif quant_type in [QuantType.per_128x128, QuantType.per_1x128]: a2_scale = a2_scale.view(hidden_states.shape[0], topk, -1, 1) a2_scale = a2_scale.repeat(1, 1, 1, 128).view(hidden_states.shape[0], topk, -1) hidden_states = hidden_states * a2_scale @@ -1109,11 +1208,12 @@ def torch_moe_stage2( w2 = w2.view(w2_shape) elif quant_type == QuantType.per_1x32: a2_shape = hidden_states.shape - a2_scale = a2_scale[: a2_shape[0] * topk] - a2_scale = a2_scale.view(token_num, topk, inter_dim // 32, 1) - hidden_states = ( - hidden_states.view(token_num, topk, inter_dim // 32, 32) * a2_scale - ) + if a2_scale is not None: + a2_scale = a2_scale[: a2_shape[0] * topk] + a2_scale = a2_scale.view(token_num, topk, inter_dim // 32, 1) + hidden_states = ( + hidden_states.view(token_num, topk, inter_dim // 32, 32) * a2_scale + ) hidden_states = hidden_states.view(a2_shape) w2_shape = w2.shape @@ -1133,11 +1233,110 @@ def torch_moe_stage2( sub_tokens = hidden_states[mask] act_input = sub_tokens @ (w2[E_id].transpose(0, 1)) out[mask] = act_input + if w2_bias is not None: + out[mask] = out[mask] + w2_bias[E_id].view(1, -1) if doweight: out = out * topk_weights.view(token_num, -1, 1) return out.sum(1).to(dtype) +def cktile_moe_stage1( + hidden_states, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + block_m, + a1_scale, + w1_scale, + sorted_weights=None, + n_pad_zeros=0, + k_pad_zeros=0, + bias1=None, +): + token_num = hidden_states.shape[0] + _, n1, k1 = w1.shape + _, k2, n2 = w2.shape + D = n2 if k2 == k1 else n2 * 2 # bit4 format + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + if w1.dtype is torch.uint32: + D = D * 8 + out = torch.empty( + (token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device + ) + # print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0])) + aiter.moe_cktile2stages_gemm1( + hidden_states, + w1, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a1_scale, + w1_scale, + bias1, + block_m, + ) + return out + + +def cktile_moe_stage2( + a2, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + w2_scale, + a2_scale, + block_m, + sorted_weights=None, + zeros_out=False, + n_pad_zeros=0, + k_pad_zeros=0, + bias2=None, +): + token_num = a2.shape[0] + D = w2.shape[1] + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + # out = torch.empty( + # (token_num, D), + # dtype=a2.dtype, + # device=a2.device, + # ) + # if zeros_out: + # out.fill_(0) + # print("Run cktile_moe_stage2: M=%d, N=%d, K=%d, topk=%d, expert=%d"%(a2.shape[0]*a2.shape[1], w2.shape[1], a2.shape[2], topk, w2.shape[0])) + aiter.moe_cktile2stages_gemm2( + a2, + w2, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a2_scale, + w2_scale, + bias2, + block_m, + ) + return out + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 46494488ac..1fb118b592 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -273,14 +273,16 @@ "srcs": [ "f'{AITER_CSRC_DIR}/pybind/deepgemm_pybind.cu'", "f'{AITER_CSRC_DIR}/ck_deepgemm/deepgemm.cu'" - ], "flags_extra_cc": [], "flags_extra_hip": [], "md_name": "'module_deepgemm'", "extra_ldflags": "None", - "extra_include": ["f'{CK_DIR}/example/ck_tile/18_flatmm'", "f'{AITER_CSRC_DIR}/ck_deepgemm/include'"], - "verbose": "False", + "extra_include": [ + "f'{CK_DIR}/example/ck_tile/18_flatmm'", + "f'{AITER_CSRC_DIR}/ck_deepgemm/include'" + ], + "verbose": "False", "is_python_module": "True", "is_standalone": "False", "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", @@ -392,6 +394,24 @@ "hip_clang_path": "os.environ.get('GEMM_A4W4_BLOCKWISE_HIP_CLANG_PATH')", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py --working_path {{}}'" }, + "module_moe_cktile2stages": { + "srcs": [ + "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu'", + "f'{AITER_CSRC_DIR}/pybind/moe_cktile_2stages_pybind.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "md_name": "'module_moe_cktile2stages'", + "extra_ldflags": "None", + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/include'" + ], + "verbose": "False", + "is_python_module": "True", + "is_standalone": "False", + "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/gen_instances.py --working_path {{}}'" + }, "module_moe_sorting": { "srcs": [ "f'{AITER_CSRC_DIR}/py_itfs_ck/moe_sorting_kernels.cu'", @@ -966,7 +986,8 @@ "module_mla_reduce": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/mla_reduce_pybind.cu'", - "f'{AITER_CSRC_DIR}/kernels/mla/reduce.cu'"], + "f'{AITER_CSRC_DIR}/kernels/mla/reduce.cu'" + ], "flags_extra_cc": [], "flags_extra_hip": [], "extra_ldflags": "None", @@ -974,4 +995,4 @@ "verbose": "False", "blob_gen_cmd": "''" } -} +} \ No newline at end of file diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index f70438b4e0..07fa8ff94c 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -313,6 +313,112 @@ def ck_moe_stage2( ) -> None: ... +@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1") +def moe_cktile2stages_gemm1_ck( + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, +) -> Tensor: ... + + +def moe_cktile2stages_gemm1( + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, +): + return moe_cktile2stages_gemm1_ck( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + block_m, + ) + + +@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm2") +def moe_cktile2stages_gemm2_ck( + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, +) -> Tensor: ... + + +def moe_cktile2stages_gemm2( + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, +): + return moe_cktile2stages_gemm2_ck( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + block_m, + ) + + dtype2str_dict = { dtypes.fp16: "f16", dtypes.bf16: "b16", diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 3d10076cd1..1ea0e35ac7 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -23,3 +23,89 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te x_ = x_.contiguous() x_ = x_.view(*x.shape) return x_.view(x_type) + + +def shuffle_weight_NK( + x: torch.Tensor, inst_N: int, inst_K: int, use_int4=False +) -> torch.Tensor: + kPerLane = inst_K // (64 // inst_N) + if use_int4: + kPerLane *= 2 + assert ( + x.shape[-2] % inst_N == 0 + ), f"{x.shape[-2]} % {inst_N} == {x.shape[-2] % N_WARP_TILE }" + assert ( + x.shape[-1] % inst_K == 0 + ), f"{x.shape[-1]} % {inst_K} == {x.shape[-1] % K_WARP_TILE }" + + x_ = x + x_ = x_.view( + -1, x.shape[-2] // inst_N, inst_N, x.shape[-1] // inst_K, 64 // inst_N, kPerLane + ) + x_ = x_.permute(0, 1, 3, 4, 2, 5).contiguous() + return x_.view(*x.shape) + + +def shuffle_weight_a16w4(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: + """ + src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 + Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] + """ + # print("gemm shape:", src.shape) + src_type = src.dtype + if hasattr(torch, "float4_e2m1fn_x2") and src_type == torch.float4_e2m1fn_x2: + src = src.view(torch.uint8) + experts_cnt, N, K_pk = src.shape + if gate_up: + N = N // 2 + KPack = 16 + KLane = 64 // NLane # 4 + N0 = N // NLane + K0 = K_pk // (KLane * KPack) + if gate_up: + src_reshaped = src.view( + experts_cnt, 2, N0, NLane, K0, KLane, KPack + ) # [E,2, N0, NLane ,K0, KLane, KPack] + src_reshaped = src_reshaped.permute( + 0, 2, 1, 4, 5, 3, 6 + ).contiguous() # [E, N0, 2, K0, KLane, NLane, KPack] + interleaved = src_reshaped.view(*src.shape) + else: + src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) + interleaved = ( + src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) + ) + # print("interleaved shape:", interleaved.shape) + return interleaved.contiguous().view(src_type) + + +def shuffle_scale_a16w4( + src: torch.Tensor, experts_cnt: int, gate_up: bool +) -> torch.Tensor: + n_experts, k_ = src.shape + n_ = n_experts // experts_cnt + # MXFP4 constants + K_Pack = 2 + N_Pack = 2 + N_Lane = 16 + K_Lane = 64 // N_Lane # 4 + + # Basic dimensions + K1 = k_ // K_Pack // K_Lane # k_ // 8 + N1 = n_ // N_Lane // N_Pack # n_ // 32 + real_k = 32 * k_ * K_Pack * K_Lane # 1x32 quant + assert real_k >= 256, f"K {real_k} must be larger than Tile_K(256)" + # print("src shape", src.shape) + # Reshape based on moe_kind + if gate_up: + # Reshape to: [E, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane] + shfl_scale = src.view(experts_cnt, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane) + # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] + shfl_scale = shfl_scale.permute(0, 2, 4, 6, 3, 5, 1).contiguous() + else: + # Reshape to: [E, K1, K_Pack, K_Lane, N1, N_Pack, N_Lane] + shfl_scale = src.view(experts_cnt, N1, N_Pack, N_Lane, K1, K_Pack, K_Lane) + # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] + shfl_scale = shfl_scale.permute(0, 1, 4, 6, 3, 5, 2).contiguous() + # print("shf_scale shape:", shfl_scale.shape) + return shfl_scale.view(*src.shape).contiguous() diff --git a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py new file mode 100644 index 0000000000..03d13d1846 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os +import argparse +from pathlib import Path +import shutil +import re +from moe_cktile2stages_common import ( + kernelInstance, + get_gemm1_kernels_list, + get_gemm2_kernels_list, + get_heuristic_dispatch_template, +) +import sys + +this_dir = os.path.dirname(os.path.abspath(__file__)) +AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") +if os.path.exists(os.path.join(AITER_CORE_DIR, "aiter_meta")): + AITER_CORE_DIR = os.path.join(AITER_CORE_DIR, "aiter/jit/utils") # pip install mode +else: + AITER_CORE_DIR = os.path.abspath( + f"{this_dir}/../../aiter/jit/utils" + ) # develop mode +sys.path.insert(0, AITER_CORE_DIR) + + +class cktile_moe_2stage_gemm_codegen: + def __init__( + self, + working_path, + ab_dtype, + acc_dtype, + c_dtype, + quant_type, + activation, + mul_routed_weight_stage, + istune=False, + ): + self.working_path = working_path + self.impl_path = os.path.join(working_path, "impl") + self.instances_path = os.path.join(working_path, "instances") + self.istune = istune + self.ab_dtype = ab_dtype.lower() + self.acc_dtype = acc_dtype.lower() + self.c_dtype = c_dtype.lower() + self.quant_type = quant_type + self.activation = activation + self.mul_routed_weight_stage = mul_routed_weight_stage + + def get_suffix(self, stage: int) -> str: + return ("_").join( + element + for element in [ + self.quant_type, + "MulRoutedWeight" if self.mul_routed_weight_stage == stage else "", + "" if (stage == 2) else self.activation, + ] + if element != "" + ) + + def gen_instance(self, k: kernelInstance): + INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages_common.cuh" + +template +torch::Tensor +{k.name}( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias) +{{{{ + // The smallest kernel we have available. Works well for memory bound shapes. + int NumTokens = XQ.size(0); + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); + int E = WQ.size(0); + int KBatch = 1; + int stride_A = K; + int stride_B = K; + int stride_C = N / {3 - k.stage}; //gemm1 gate+up need / 2. + void *sorted_weights_ptr = topk_weight.has_value() ? topk_weight.value().data_ptr() : nullptr; + + {{INSTANCE_CONTENT}} + return Y; +}}}} + +""" + # default no quant + scaleGranA = "-1" + scaleGranB = "-1" + biasGran = "-1" + xptr = "nullptr" + wptr = "nullptr" + biasptr = "nullptr" + if k.QuantType == "per_tenser": + scaleGranA = "0" + scaleGranB = "0" + xptr = "static_cast(x_scale.value().data_ptr()[0])" + wptr = "static_cast(w_scale.value().data_ptr()[0])" + elif k.QuantType == "per_token": + scaleGranA = "1" + scaleGranB = "1" + xptr = "static_cast(x_scale.value().data_ptr())" + wptr = "static_cast(w_scale.value().data_ptr())" + elif k.QuantType == "1x32": + scaleGranA = "-1" + scaleGranB = "1, 32" + biasGran = "1" + xptr = "nullptr" + wptr = "static_cast(w_scale.value().data_ptr())" + biasptr = "static_cast(exp_bias.value().data_ptr())" + + INSTANCE_CONTENT = f"""auto per_a_scale_dev_ptr = ck_tile::FlatmmScalePointer<{scaleGranA}>{{{xptr}}}; + auto per_b_scale_dev_ptr = ck_tile::FlatmmScalePointer<{scaleGranB}>{{{wptr}}}; + auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<{biasGran}>{{{biasptr}}}; + ck_tile::MoeFlatmmHostArgs kernel_args{{ + reinterpret_cast(sorted_ids.data_ptr()), + sorted_weights_ptr, + reinterpret_cast(sorted_expert_ids.data_ptr()), + reinterpret_cast(max_token_ids.data_ptr()), + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + reinterpret_cast(Y.data_ptr()), + NumTokens, + E, + topk, + 1, // k_batch + M, + N, + K, + stride_A, + stride_B, + stride_C, + n_padded_zeros.value(), + k_padded_zeros.value(), + per_a_scale_dev_ptr, + per_b_scale_dev_ptr, + exp_bias_dev_ptr + }}; + using TileConfig = MoeFlatmmConfig; + // Run kernel instance. + auto stream_config = ck_stream_config{{at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream()}}; + moe_gemm, + AccDataType, + CDataType, + row_major, + col_major, + ck_tile::tuple<>, + row_major, + {"ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up" if k.stage == 1 else "ck_tile::MoeFlatmmKind::kFFN_gemm2"}, + ck_tile::element_wise::PassThrough + >(kernel_args, stream_config); +""" + + INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT=(INSTANCE_CONTENT)) + + Path(os.path.join(self.impl_path, f"{k.name}.cuh")).write_text( + INSTANCE_IMPL_str + ) + + INSTANCE_template = """// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "../impl/{name}.cuh" + +template torch::Tensor +{name}<{dtypes}>( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias); + +""" + + # if self.istune: + # INSTANCE_abI8_dBF16_eBF16 = INSTANCE_template.format( + # name=k.name, dtypes="I8, B16" + # ) + # Path( + # os.path.join(self.instances_path, f"{k.name}_abI8_dB16_eB16.cpp") + # ).write_text(INSTANCE_abI8_dBF16_eBF16) + # else: + def fill_template(name, a_type, b_type, acc_type, c_type): + nonlocal self + intsance = INSTANCE_template.format( + name=name, dtypes=f"{a_type}, {b_type}, {acc_type}, {c_type}" + ) + Path( + os.path.join( + self.instances_path, + f"{name}_a{a_type}_b{b_type}_acc{acc_type}_C{c_type}.cpp", + ) + ).write_text(intsance) + + if (k.QuantType == "1x32") and (self.ab_dtype in ["bf16", "fp16"]): + fill_template(k.name, self.ab_dtype, "pk_fp4", self.acc_dtype, self.c_dtype) + else: + for CDtype in ["bf16", "fp16"]: + for ABDtype in ["fp8"]: # "bf16", "fp16", + for AccDtype in ["float"]: + fill_template(k.name, ABDtype, ABDtype, AccDtype, CDtype) + # intsance = INSTANCE_template.format( + # name=k.name, dtypes=f"{ABDtype}, {AccDtype}, {CDtype}" + # ) + # Path( + # os.path.join( + # self.instances_path, + # f"{k.name}_ab{ABDtype}_acc{AccDtype}_C{CDtype}.cpp", + # ) + # ).write_text(intsance) + + """genarete heuristic dispatch""" + + def gen_heuristic_dispatch(self, tag, kernels_dict): + HEURISTIC_template = get_heuristic_dispatch_template(tag) + # print(HEURISTIC_template) + + def validate_and_format(template: str, mapping: dict) -> str: + # check all format element in dict. + str_mapping = {str(key): value.name for key, value in mapping.items()} + cleaned_template = template.replace("{{", "").replace("}}", "") + placeholders = re.findall(r"\{([^{}]*)\}", cleaned_template) + missing = [p for p in placeholders if p not in str_mapping] + # print(placeholders) + # print(str_mapping) + if missing: + raise KeyError(f"Missing keys in mapping: {missing}") + result = template + # for placeholder in placeholders: + # result = result.replace(placeholder, str_mapping[placeholder]) + # return result + return template.format(**{k: v for k, v in str_mapping.items()}) + + # create heuristic heirarchy + with open( + os.path.join(self.working_path, "moe_cktile2stages_heuristic_dispatch.h"), + "w", + ) as f: + f.write(validate_and_format(HEURISTIC_template, kernels_dict)) + # arch = get_gfx() + # inst_k = "32" if self.quant_type == "1x32" else ("128" if arch == "gfx950" else "64") + # f.write( + # HEURISTIC_template.format( + # inst_k=inst_k, + # suffix1 = self.get_suffix(1), + # suffix2 = self.get_suffix(2) + # ) + # ) + + """generate lookup.h linking MNK/datatype to specific instance""" + + def gen_lookup_dict(self, kernels_dict): + LOOKUP_head = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +// #ifdef USE_ROCM + +#define GENERATE_LOOKUP_TABLE(ABTYPE, ACCTYPE, CTYPE) \\ + { \\""" + + LOOKUP_template = """ + {{{MNK}, \\ + {kernel_name}}}, \\""" + + LOOKUP_end = """ + } + +// #endif // USE_ROCM +""" + with open( + os.path.join(self.working_path, "moe_cktile2stages_lookup.h"), "w" + ) as f: + f.write(LOOKUP_head) + for mnk, k in kernels_dict.items(): + print(":", k.name) + # if not tunning, tuned mnk = {stage, m, n, k} + if not self.istune and ( + isinstance(mnk, tuple) and (len(mnk) == 4) and mnk[1] > 0 + ): + f.write( + LOOKUP_template.format( + MNK="{" + + (", ").join(map(lambda x: str(x), list(mnk))) + + "}", + kernel_name=k.name, + ) + ) + # if tunning, mnk = -1,-2..... + elif self.istune and isinstance(mnk, int): + f.write(LOOKUP_template.format(MNK=mnk, kernel_name=k.name)) + f.write(LOOKUP_end) + + """generate manifest.h for instance header""" + + def gen_manifest_head(self, kernels_dict): + MAINFEST_head = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// #ifdef USE_ROCM + +#include + +#include +""" + MAINFEST_template = """ +template +torch::Tensor +{kernel_name}( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias); +""" + MAINFEST_end = """ + +// endif // USE_ROCM +""" + + with open( + os.path.join(self.working_path, "moe_cktile2stages_manifest.h"), "w" + ) as f: + f.write(MAINFEST_head) + for mnk, k in kernels_dict.items(): + f.write(MAINFEST_template.format(kernel_name=k.name)) + f.write(MAINFEST_end) + + """generate all instances and headers""" + + def gen_instances(self, tag, kernels_dict): + if os.path.exists(self.impl_path): + shutil.rmtree(self.impl_path) + os.mkdir(self.impl_path) + if os.path.exists(self.instances_path): + shutil.rmtree(self.instances_path) + os.mkdir(self.instances_path) + + for mnk, k in kernels_dict.items(): + self.gen_instance(k) + + self.gen_lookup_dict(kernels_dict) + self.gen_manifest_head(kernels_dict) + self.gen_heuristic_dispatch(tag, kernels_dict) + + +# def get_tune_dict(tune_dict_csv): +# tune_dict = default_kernels_dict +# if os.path.exists(tune_dict_csv): +# tune_df = pd.read_csv(tune_dict_csv) +# if torch.cuda.is_available(): +# gpu = torch.cuda.current_device() +# device_properties = torch.cuda.get_device_properties(gpu) +# cu_num = device_properties.multi_processor_count +# tune_df = tune_df[tune_df["cu_num"] == cu_num].reset_index() +# for i in range(len(tune_df)): +# M = tune_df.loc[i, "M"] +# N = tune_df.loc[i, "N"] +# K = tune_df.loc[i, "K"] +# kid = tune_df.loc[i, "kernelId"] +# tune_dict[(M, N, K)] = kernels_list[kid] +# return tune_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate ck_tile 2stage gemm instance." + ) + + # Add arguments + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated", + ) + + parser.add_argument( + "-f", + "--tune_file", + default="aiter/configs/a8w8_tuned_gemm.csv", + required=False, + help="tune_file include the result after run gemm_a8w8_tune.py", + ) + + parser.add_argument( + "-a", + "--a_dtype", + nargs="*", + required=False, + type=str, + choices=["f8", "i8", "f16", "b16"], + help="select input dtype", + ) + + parser.add_argument( + "-b", + "--b_dtype", + nargs="*", + required=False, + type=str, + choices=["f8", "i8", "f16", "b16", "i4"], + help="select weight dtype", + ) + + parser.add_argument( + "-c", + "--c_dtype", + default="b16", + required=False, + type=str, + choices=["f16", "b16"], + help="select out dtype", + ) + + parser.add_argument( + "-q", + "--quant_type", + default="per_tensor", + required=False, + type=str, + choices=[ + "per_tensor", + "per_token", + "1x32", + "128x128", + "no", + ], + help="select quant_type", + ) + + parser.add_argument( + "-act", + "--activation", + default="silu", + required=False, + type=str, + choices=["silu", "gelu"], + help="select activation", + ) + + parser.add_argument( + "-m", + "--mul_routed_weight_stage", + default=2, + required=False, + type=int, + choices=[1, 2], + help="select quant_type", + ) + + args = parser.parse_args() + + # # build all + # if args.b_dtype is None: + # # quanted moe + # b_quant_dtypes = ["f8"] + # c_dtypes = ["f16", "b16"] + # acts = ["silu"] #, "gelu"] + # general_quant_l = ["per_tensor", "per_token"] + # for b_dtype, c_dtype, act, quant in itertools.product( + # b_quant_dtypes, c_dtypes, acts, general_quant_l + # ): + # a_dtype = b_dtype + # codegen = cktile_moe_2stage_gemm_codegen( + # args.working_path, + # a_dtype, + # b_dtype, + # c_dtype, + # quant, + # act, + # ) + # codegen.generate_instance_and_lookUpTable() + + # # no-quant moe + # b_quant_dtypes = [ + # "f16", + # "b16", + # ] + # for ( + # b_dtype, + # act, + # ) in itertools.product(b_quant_dtypes, acts): + # c_dtype = a_dtype = b_dtype + + # codegen = cktile_moe_2stage_gemm_codegen( + # args.working_path, + # a_dtype, + # b_dtype, + # c_dtype, + # "no", + # act, + # ) + # codegen.generate_instance_and_lookUpTable() + # else: + + # single UT + # a_type = "fp8" + # b_type = "fp8" + # quant_type = "per_token" + + a_type = "bf16" + b_type = "fp4" + quant_type = "1x32" + + acc_type = "float" + c_type = "bf16" + act_type = "silu" + codegen = cktile_moe_2stage_gemm_codegen( + args.working_path, a_type, acc_type, c_type, quant_type, act_type, 2, False + ) + # gen all instances for gemm1 and gemm2 + _, gemm1_kernel_list = get_gemm1_kernels_list( + a_type, + b_type, + quant_type, + act_type, + False, + ) + tag, gemm2_kernel_list = get_gemm2_kernels_list( + a_type, + b_type, + quant_type, + "", + True, + ) + # merge gemm1/gemm2 dict with key = {stage, key} + kernel_dict_merge = { + **{(1, key): value for key, value in gemm1_kernel_list.items()}, + **{(2, key): value for key, value in gemm2_kernel_list.items()}, + } + # print(kernel_dict_merge) + codegen.gen_instances(tag, kernel_dict_merge) diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h new file mode 100644 index 0000000000..df9359d7bf --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h @@ -0,0 +1,74 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +// #include "moe_flatmm.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/moe_flatmm.hpp" +#include "py_itfs_common.h" +// #include +// #include +#include +#include +#include + +#include +#include +#include + +using MoeKernel = std::function, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional)>; +using ck_stream_config = ck_tile::stream_config; +using row_major = ck_tile::tensor_layout::gemm::RowMajor; +using col_major = ck_tile::tensor_layout::gemm::ColumnMajor; +using bf16 = ck_tile::bf16_t; +using fp16 = ck_tile::half_t; +using fp8 = ck_tile::fp8_t; +using pk_fp4 = ck_tile::pk_fp4_t; + +__attribute__((visibility("default"))) torch::Tensor +cktile_moe_gemm1(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m); + +__attribute__((visibility("default"))) torch::Tensor +cktile_moe_gemm2(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m); \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh new file mode 100644 index 0000000000..cd8d2724fa --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/moe_flatmm.hpp" +#include "moe_cktile2stages.h" +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +template +struct MoeFlatmmConfig +{ + static constexpr ck_tile::index_t M_Tile = M_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t K_Tile = K_Tile_; + + static constexpr ck_tile::index_t M_Warp = M_Warp_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = kBlockPerCu_; + static constexpr int TileParitionerGroupNum = 1; + static constexpr int TileParitionerM01 = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + +template +void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) +{ + using CodegenFlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits; // Preshuffle_ + + constexpr bool MXFP4_Pipeline = std::is_same_v; + + if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) + { + static_assert( + FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0, + "requires NRepeat is multiple of 2 for FFN_gemm1_gate_up"); + } + + using ComputeDataType = ADataType; + static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), + "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); + + using GemmPipelineProblem = ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using CodegenPipelineProblem = + std::conditional_t, + ck_tile::FlatmmPipelineProblem>; + + constexpr int BlockedXDLN_PerWarp = + (MXFP4_Pipeline || (moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)) + ? 2 + : 1; // determined by scale shuffle pattern + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPipeline = std::conditional_t< + MXFP4_Pipeline, + ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1, + ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1>; + + using FusedAct = + std::conditional_t; + + using Kernel = ck_tile::MoeFlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + // if(!Kernel::IsSupportedArgument(kargs)) + // { + // throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + // } + + // if(s.log_level_ > 0) + // { + // std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" + // << "Shape: " << CodegenFlatmmShape::GetName() << "\n" + // << "problem: " << CodegenPipelineProblem::GetName() << "\n" + // << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n" + // << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + // << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + // << std::endl; + // } + // + // if(s.flush_cache_) + // { + // std::cout << "Flushing cache..." << std::endl; + // static constexpr ck_tile::index_t APackedSize = + // std::is_same_v ? 2 : 1; + // static constexpr ck_tile::index_t BPackedSize = + // std::is_same_v ? 2 : 1; + + // ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + // moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK + // : args.NumTokens, + // args.K, + // args.stride_A, + // is_row_major(ALayout{}))); + // ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + // args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{}))); + + // const int outputN = + // moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N; + + // auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + // auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + // ck_tile::RotatingMemWrapper rotating_mem( + // kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + // rotating_mem.Print(); + + // auto run_flush_cache = [&]() { + // // flush icache + // ck_tile::flush_icache(); + // // rotating mem + // rotating_mem.Next(); + // // clear c mem + // if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2) + // hipGetErrorString(hipMemsetAsync( + // args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), + // s.stream_id_)); + // else if(args.k_batch > 1) + // hipGetErrorString( + // hipMemsetAsync(args.e_ptr, + // 0, + // args.NumTokens * args.TopK * outputN * sizeof(CDataType), + // s.stream_id_)); + // }; + // ave_time = ck_tile::launch_kernel_preprocess( + // s, + // run_flush_cache, + // ck_tile::make_kernel( + // Kernel{}, grids, blocks, 0, kargs)); + // } + // else + // { + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + // } + // return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +} \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu new file mode 100644 index 0000000000..73674ed146 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages_common.cuh" +#include "moe_cktile2stages_lookup.h" +#include "moe_cktile2stages_manifest.h" +#include "py_itfs_common.h" +#include "moe_cktile2stages_heuristic_dispatch.h" +#include + +template +MoeKernel moe_dispatch(int M, int N, int K, int block_m) +{ + // For a given shape, either find the best kernel via lookup or heuristic. + // For many small M shapes, we bucket them to the next largest kernel. + // This is fine since kernels are padded anyway. + + // static const auto lookup = [&] + // { + // return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(ABDataType, AccDataType, CDataType)}; + // }(); + + // // First check if this shape(M,N,K) is available in the direct lookup. + // auto it = lookup.find({M, N, K}); + // // If we found an optimal kernel, use it. + // if (it != lookup.end()) + // { + // return it->second; + // } + + // int padded_m = M; + // if (M > 1 && M <= 16) + // { + // padded_m = 16; + // } + // else if (M <= 16384) + // { + // padded_m = nextPow2(M); + // } + // else if (M <= 20480) + // { + // padded_m = 20480; + // } + // // Second check if this shape(padded_m,N,K) is available in the direct lookup. + // it = lookup.find({padded_m, N, K}); + // // If we found an optimal kernel, use it. + // if (it != lookup.end()) + // { + // return it->second; + // } + // Otherwise, use heuristics. + if(stage == 1) + { + return moe_gemm1_heuristic_dispatch( + M, N, K, block_m); + } + else + { + return moe_gemm2_heuristic_dispatch( + M, N, K, block_m); + } +} + +torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m) +{ + TORCH_CHECK(Y.dtype() == at::ScalarType::BFloat16 || Y.dtype() == at::ScalarType::Half, + "Out dtype only support BFloat16/Float16!"); + if(x_scale != std::nullopt && w_scale != std::nullopt) + { + TORCH_CHECK(x_scale.value().dtype() == w_scale.value().dtype(), + "Scales should have the same dtype!"); + } + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); + int MPerBlock = block_m.value(); + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); + at::hip::getCurrentHIPStream(); + // if (!XQ || !WQ || !sorted_ids || !sorted_expert_ids || !max_token_ids || !topk_weight) + // { + // std::cerr << "detect null ptr !" << std::endl; + // return; + // } + + if(XQ.dtype() == torch_fp8) + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + // if (Y.dtype() == at::ScalarType::BFloat16) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + } + else if((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && + (WQ.dtype() == torch_fp4x2)) // a16w4 + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + if(Y.dtype() == at::ScalarType::BFloat16) + { + moe_dispatch(M, N, K, MPerBlock)(XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias); + } + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + return Y; +} + +torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m) +{ + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); + int MPerBlock = block_m.value(); + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); + at::hip::getCurrentHIPStream(); + // if (!XQ. || !WQ || !sorted_ids || !sorted_expert_ids || !max_token_ids || !topk_weight) + // { + // std::cerr << "detect null ptr !" << std::endl; + // return; + // } + + if(XQ.dtype() == torch_fp8) + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + // if (Y.dtype() == at::ScalarType::BFloat16) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + } + else if((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && + (WQ.dtype() == torch_fp4x2)) // a16w4 + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + if(Y.dtype() == at::ScalarType::BFloat16) + { + moe_dispatch(M, N, K, MPerBlock)(XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias); + } + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + return Y; +} \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py new file mode 100644 index 0000000000..f1be74edd8 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py @@ -0,0 +1,448 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +from dataclasses import dataclass +import os +import sys + +this_dir = os.path.dirname(os.path.abspath(__file__)) +AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") +if os.path.exists(os.path.join(AITER_CORE_DIR, "aiter_meta")): + AITER_CORE_DIR = os.path.join(AITER_CORE_DIR, "aiter/jit/utils") # pip install mode +else: + AITER_CORE_DIR = os.path.abspath( + f"{this_dir}/../../aiter/jit/utils" + ) # develop mode +sys.path.insert(0, AITER_CORE_DIR) + +from chip_info import get_gfx # noqa: E402 + + +@dataclass +class kernelInstance: + stage: int + BLOCK_SIZE: int + MPerBlock: int + NPerBlock: int + KPerBlock: int + WAVE_TILE_M: int + WAVE_TILE_N: int + WAVE_TILE_K: int + WAVE_MAP_M: int + WAVE_MAP_N: int + Block_Per_CU: int = 1 + MulRoutedWeight: bool = False + ActOP: str = "silu" + QuantType: str = "per_tensor" + + @property + def name(self) -> str: + return ("_").join( + element + for element in [ + f"moe_cktile2stages_gemm{self.stage}", + ("x").join( + map( + lambda x: str(x), + [ + self.BLOCK_SIZE, + self.MPerBlock, + self.NPerBlock, + self.KPerBlock, + ], + ) + ), + ("x").join(map(lambda x: str(x), [self.WAVE_MAP_M, self.WAVE_MAP_N])), + ("x").join( + map( + lambda x: str(x), + [self.WAVE_TILE_M, self.WAVE_TILE_N, self.WAVE_TILE_K], + ) + ), + str(self.Block_Per_CU) + "perCU", + self.QuantType, + "MulRoutedWeight" if self.MulRoutedWeight else "", + "" if (self.stage == 2) else self.ActOP, + ] + if element != "" + ) + + +# fmt: off +# gemm1 out:bf16/fp16 AB:fp8/i8 +a8w8_gemm1_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + # 0: kernelInstance( 1, 256, 32, 64, 256, 16, 16, 128, 1, 4,), + 1: kernelInstance( 1, 256, 32, 128, 128, 16, 16, 128, 1, 4,), + 2: kernelInstance( 1, 256, 64, 128, 128, 16, 16, 128, 1, 4,), + 4: kernelInstance( 1, 256, 64, 128, 256, 16, 16, 128, 1, 4,), + 4: kernelInstance( 1, 256, 128, 128, 128, 16, 16, 128, 1, 4,), + 5: kernelInstance( 1, 256, 128, 128, 128, 16, 16, 128, 1, 4,), + 6: kernelInstance( 1, 256, 256, 128, 128, 16, 16, 128, 1, 4,), +} + +# gemm2 out:bf16/fp16 AB:fp8/i8 +a8w8_gemm2_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + 0: kernelInstance( 2, 256, 32, 128, 256, 16, 16, 128, 1, 4,), + 1: kernelInstance( 2, 256, 64, 128, 256, 16, 16, 128, 1, 4,), + 2: kernelInstance( 2, 256, 128, 128, 128, 16, 16, 128, 1, 4,), + 3: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 128, 1, 4,), + 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 128, 1, 4,), +} + + +#a8w8 +a8w8_gemm1_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + # 0: kernelInstance( 1, 256, 32, 64, 256, 16, 16, 64, 1, 4,), + # 1: kernelInstance( 1, 256, 32, 64, 128, 16, 16, 64, 1, 4,), + # 2: kernelInstance( 1, 256, 64, 64, 256, 16, 16, 64, 2, 2,), + # 3: kernelInstance( 1, 256, 64, 64, 128, 16, 16, 64, 1, 4,), + 3: kernelInstance( 1, 256, 64, 128, 128, 16, 16, 64, 1, 4), + # 4: kernelInstance( 1, 256, 128, 64, 128, 16, 16, 64, 1, 4,), + # 5: kernelInstance( 1, 256, 128, 128, 128, 16, 16, 64, 1, 4,), + # 6: kernelInstance( 1, 256, 256, 128, 128, 16, 16, 64, 1, 4,), +} +a8w8_gemm2_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + # 0: kernelInstance( 2, 256, 32, 64, 256, 16, 16, 64, 1, 4,), + # 1: kernelInstance( 2, 256, 64, 64, 256, 16, 16, 64, 1, 4,), + # 2: kernelInstance( 2, 256, 128, 64, 128, 16, 16, 64, 1, 4,), + # 3: kernelInstance( 2, 256, 256, 64, 128, 16, 16, 64, 1, 4,), + # 4: kernelInstance( 2, 256, 64, 128, 256, 16, 16, 128, 1, 4,), + # 5: kernelInstance( 2, 256, 128, 128, 128, 16, 16, 64, 1, 4,), + # 6: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 64, 1, 4,), + # 7: kernelInstance( 2, 256, 32, 64, 128, 16, 16, 64, 1, 4,), + 8: kernelInstance( 2, 256, 64, 128, 128, 16, 16, 64, 1, 4,), +} + + +# gemm1 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm1_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N|| BlockPerCU| + 0: kernelInstance( 1, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 1, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 1, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 1, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 1, 256, 128, 256, 256, 16, 16, 32, 1, 4, 1,), +} +# gemm1 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm1_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N|| BlockPerCU| + 0: kernelInstance( 1, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 1, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 1, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 1, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 1, 256, 128, 256, 256, 16, 16, 32, 1, 4, 1,), +} +# gemm2 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm2_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| BlockPerCU| + 0: kernelInstance( 2, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 2, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 2, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 2, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 128, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 256, 256, 256, 16, 16, 32, 1, 4,), + # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), +} +# gemm2 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm2_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| BlockPerCU| + 0: kernelInstance( 2, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 2, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 2, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 2, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 128, 256, 128, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 256, 256, 256, 16, 16, 32, 1, 4,), + # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), +} + +# fmt: on +gemm1_kernels_dict = { + "a8w8_gfx950": a8w8_gemm1_kernels_list_gfx950, + "a8w8": a8w8_gemm1_kernels_list, + "a16w4_gfx950": a16w4_gemm1_kernels_list_gfx950, + "a16w4": a16w4_gemm1_kernels_list, +} + +gemm2_kernels_dict = { + "a8w8_gfx950": a8w8_gemm2_kernels_list_gfx950, + "a8w8": a8w8_gemm2_kernels_list, + "a16w4_gfx950": a16w4_gemm2_kernels_list_gfx950, + "a16w4": a16w4_gemm2_kernels_list, +} + + +a8w8_gfx950_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 32) + {{ + return {(1, 1)}; + }} + else if (block_m == 64) + {{ + return {(1, 2)}; + }} + //else if (block_m == 128) + //{{ + // return {(1, 4)}; + //}} + //else if (block_m == 256) + //{{ + // return {(1, 6)}; + //}} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} +}} + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 32) + {{ + return {(2, 0)}; + }} + else if (block_m == 64) + {{ + return {(2, 1)}; + }} + //else if (block_m == 128) + //{{ + // return {(2, 2)}; + //}} + //else if (block_m == 256) + //{{ + // return {(2, 3)}; + //}} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm1 heuristic dispatch: ", + block_m); + }} +}} +""" + +a16w4_gfx950_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(1, 0)}; + }} + else if (block_m == 32) + {{ + return {(1, 1)}; + }} + else if (block_m == 64) + {{ + return {(1, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} +}} + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(2, 0)}; + }} + else if (block_m == 32) + {{ + return {(2, 1)}; + }} + else if (block_m == 64) + {{ + return {(2, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm2 heuristic dispatch: ", + block_m); + }} +}} +""" + +a16w4_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(1, 0)}; + }} + else if (block_m == 32) + {{ + return {(1, 1)}; + }} + else if (block_m == 64) + {{ + return {(1, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} +}} + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(2, 0)}; + }} + else if (block_m == 32) + {{ + return {(2, 1)}; + }} + else if (block_m == 64) + {{ + return {(2, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm2 heuristic dispatch: ", + block_m); + }} +}} +""" + +heuristic_dispatch_dict = { + "a8w8_gfx950": a8w8_gfx950_heuristic_dispatch, + # "a8w8": a8w8_gemm2_kernels_list, + "a16w4_gfx950": a16w4_gfx950_heuristic_dispatch, + "a16w4": a16w4_heuristic_dispatch, +} + + +bit8_list = ["f8", "i8", "fp8"] +bit16_list = ["b16", "f16", "bf16", "fp16"] +bit4_list = ["i4", "fp4x2", "fp4"] +QuantType_list = ["no", "per_tensor", "per_token", "per_1x128", "per_1x32"] + + +def get_gemm1_kernels_list( + Adtype: str, + Bdtype: str, + QuantType: str = "none", + ActOP: str = "silu", + MulRoutedWeight: bool = False, +) -> list: + arch = get_gfx() + if Adtype.lower() in bit8_list and Bdtype.lower() in bit8_list and Adtype == Bdtype: + if arch == "gfx950": + tag = "a8w8_gfx950" + else: + tag = "a8w8" + elif Adtype in bit16_list and Bdtype in bit4_list: + if arch == "gfx950": + tag = "a16w4_gfx950" + else: + tag = "a16w4" + else: + raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") + kernels_list = gemm1_kernels_dict[tag] + for id, kernel in kernels_list.items(): + kernel.MulRoutedWeight = MulRoutedWeight + kernel.ActOP = ActOP + kernel.QuantType = QuantType + # if tag == "a8w4": + # kernel.CDEElementOp = "MulABScaleWint4" + # elif tag == "a8w8blkscale": + # kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" + # elif tag == "a8w8" or tag == "a4w4": + # kernel.CDEElementOp = "MulABScale" + # elif tag == "a16w16": + # if MulRoutedWeight: + # kernel.CDEElementOp = "TypeCastExpertWeight" + # else: + # kernel.CDEElementOp = "TypeCast" + return tag, kernels_list + + +def get_gemm2_kernels_list( + Adtype: str, + Bdtype: str, + QuantType: str = "", + ActOP: str = "", + MulRoutedWeight: bool = True, +) -> list: + arch = get_gfx() + if Adtype in bit8_list and Bdtype in bit8_list and Adtype == Bdtype: + if arch == "gfx950": + tag = "a8w8_gfx950" + else: + tag = "a8w8" + elif Adtype in bit16_list and Bdtype in bit4_list: + if arch == "gfx950": + tag = "a16w4_gfx950" + else: + tag = "a16w4" + else: + raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") + kernels_list = gemm2_kernels_dict[tag] + for id, kernel in kernels_list.items(): + kernel.MulRoutedWeight = MulRoutedWeight + kernel.ActOP = "" + kernel.QuantType = QuantType + # if tag == "a8w4": + # kernel.CDEElementOp = "MulABScaleExpertWeightWin4" + # elif tag == "a8w8blkscale": + # kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" + # elif tag == "a8w8" or tag == "a4w4": + # kernel.CDEElementOp = "MulABScaleExpertWeight" + # elif tag == "a16w16": + # if MulRoutedWeight: + # kernel.CDEElementOp = "TypeCastExpertWeight" + # else: + # kernel.CDEElementOp = "TypeCast" + return tag, kernels_list + + +def get_heuristic_dispatch_template(tag): + if tag not in heuristic_dispatch_dict.keys(): + raise ValueError(f"Unsupported type for heuristic_dispatch: {tag}") + return heuristic_dispatch_dict[tag] diff --git a/csrc/include/aiter_enum.h b/csrc/include/aiter_enum.h index 0c35e8158f..15126c8cf6 100644 --- a/csrc/include/aiter_enum.h +++ b/csrc/include/aiter_enum.h @@ -6,7 +6,8 @@ enum class ActivationType : int { No = -1, Silu = 0, - Gelu + Gelu = 1, + Swiglu = 2, }; enum class QuantType : int { diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 74e3a9638e..fc420f1a6a 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -733,6 +733,43 @@ namespace py = pybind11; py::arg("quant_type") = 0, \ py::arg("activation") = 0); +#define MOE_CKTILE_2STAGES_PYBIND \ + m.def("cktile_moe_gemm1", \ + &cktile_moe_gemm1, \ + "cktile_moe_gemm1", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("Y"), \ + py::arg("sorted_ids"), \ + py::arg("sorted_expert_ids"), \ + py::arg("max_token_ids"), \ + py::arg("topk"), \ + py::arg("n_padded_zeros") = 0, \ + py::arg("k_padded_zeros") = 0, \ + py::arg("topk_weight") = std::nullopt, \ + py::arg("x_scale") = std::nullopt, \ + py::arg("w_scale") = std::nullopt, \ + py::arg("exp_bias") = std::nullopt, \ + py::arg("block_m") = 32); \ + \ + m.def("cktile_moe_gemm2", \ + &cktile_moe_gemm2, \ + "cktile_moe_gemm2", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("Y"), \ + py::arg("sorted_ids"), \ + py::arg("sorted_expert_ids"), \ + py::arg("max_token_ids"), \ + py::arg("topk"), \ + py::arg("n_padded_zeros") = 0, \ + py::arg("k_padded_zeros") = 0, \ + py::arg("topk_weight") = std::nullopt, \ + py::arg("x_scale") = std::nullopt, \ + py::arg("w_scale") = std::nullopt, \ + py::arg("exp_bias") = std::nullopt, \ + py::arg("block_m") = 32); + #define MHA_VARLEN_FWD_PYBIND \ m.def("mha_varlen_fwd", \ &aiter::torch_itfs::mha_varlen_fwd, \ @@ -1284,6 +1321,7 @@ namespace py = pybind11; .value("No", ActivationType::No) \ .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ + .value("Swiglu", ActivationType::Swiglu) \ .export_values(); \ pybind11::implicitly_convertible(); \ pybind11::implicitly_convertible(); @@ -1310,36 +1348,36 @@ namespace py = pybind11; py::arg("stride0"), \ py::arg("stride1")); -#define MLA_METADATA_PYBIND \ - m.def("get_mla_metadata_v1", \ - &get_mla_metadata_v1, \ - "get_mla_metadata_v1", \ - py::arg("seqlens_qo_indptr"), \ - py::arg("seqlens_kv_indptr"), \ - py::arg("num_heads_per_head_k"), \ - py::arg("num_heads_k"), \ - py::arg("is_causal"), \ - py::arg("work_metadata_ptrs"), \ - py::arg("work_info_set"), \ - py::arg("work_indptr"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("kv_granularity") = 16, \ - py::arg("max_seqlen_qo") = -1, \ - py::arg("uni_seqlen_qo") = -1, \ - py::arg("fast_mode") = true, \ - py::arg("topk") = -1); \ +#define MLA_METADATA_PYBIND \ + m.def("get_mla_metadata_v1", \ + &get_mla_metadata_v1, \ + "get_mla_metadata_v1", \ + py::arg("seqlens_qo_indptr"), \ + py::arg("seqlens_kv_indptr"), \ + py::arg("num_heads_per_head_k"), \ + py::arg("num_heads_k"), \ + py::arg("is_causal"), \ + py::arg("work_metadata_ptrs"), \ + py::arg("work_info_set"), \ + py::arg("work_indptr"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("kv_granularity") = 16, \ + py::arg("max_seqlen_qo") = -1, \ + py::arg("uni_seqlen_qo") = -1, \ + py::arg("fast_mode") = true, \ + py::arg("topk") = -1); \ m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant); -#define MLA_REDUCE_PYBIND \ - m.def("mla_reduce_v1", \ - &mla_reduce_v1, \ - "mla_reduce_v1", \ - py::arg("partial_output"), \ - py::arg("partial_lse"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("final_output"), \ - py::arg("final_lse") = std::nullopt); +#define MLA_REDUCE_PYBIND \ + m.def("mla_reduce_v1", \ + &mla_reduce_v1, \ + "mla_reduce_v1", \ + py::arg("partial_output"), \ + py::arg("partial_lse"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("final_output"), \ + py::arg("final_lse") = std::nullopt); diff --git a/csrc/pybind/moe_ck_2stages_pybind.cu b/csrc/pybind/moe_ck_2stages_pybind.cu index 6b237b1898..e720771df2 100644 --- a/csrc/pybind/moe_ck_2stages_pybind.cu +++ b/csrc/pybind/moe_ck_2stages_pybind.cu @@ -1,9 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "rocm_ops.hpp" #include "moe_ck.h" +#include "rocm_ops.hpp" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - MOE_CK_2STAGES_PYBIND; -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { MOE_CK_2STAGES_PYBIND; } diff --git a/csrc/pybind/moe_cktile_2stages_pybind.cu b/csrc/pybind/moe_cktile_2stages_pybind.cu new file mode 100644 index 0000000000..82947422ce --- /dev/null +++ b/csrc/pybind/moe_cktile_2stages_pybind.cu @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" +#include "rocm_ops.hpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { MOE_CKTILE_2STAGES_PYBIND; } diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 5fcc3571b8..6adb89fa11 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -11,115 +11,27 @@ from aiter.jit.utils.chip_info import get_gfx import argparse import pandas as pd +import numpy as np from aiter.fused_moe import ( fused_topk, - moe_sorting, fused_moe, torch_moe_stage1, torch_moe_stage2, - get_block_size_M, ) -from aiter.ops.shuffle import shuffle_weight +from aiter.ops.shuffle import ( + shuffle_weight, + shuffle_scale_a16w4, + shuffle_weight_a16w4, +) from aiter import ActivationType torch.int4 = getattr(torch, "int4", torch.uint32) torch.set_default_device("cuda") -def ck_moe_stage1( - hidden_states, - w1, # [E, inter_dim*2, model_dim] - w2, # [E, model_dim, inter_dim] - sorted_token_ids, # [max_num_tokens_padded] - sorted_expert_ids, # [max_num_m_blocks] - num_valid_ids, # [1] - w1_scale, - a1_scale, - dtype, - topk, - block_size=32, - Activation=ActivationType.Gelu, - quant_type=aiter.QuantType.No, - sorted_weights=None, # [max_num_tokens_padded] -): - token_num = hidden_states.shape[0] - D = w2.shape[-1] - # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - - if w1.dtype is torch.uint32: - D = D * 8 - - out = torch.empty((token_num, topk, D), dtype=dtype) - - aiter.ck_moe_stage1_fwd( - hidden_states, - w1, - w2, - sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - out, - topk, - "", - w1_scale, - a1_scale, - block_size, - sorted_weights, - quant_type, - Activation, - ) - - return out - - -def ck_moe_stage2( - hidden_states, - w1, # [E, inter_dim*2, model_dim] - w2, # [E, model_dim, inter_dim] - sorted_token_ids, # [max_num_tokens_padded] - sorted_expert_ids, # [max_num_m_blocks] - num_valid_ids, # [1] - w2_scale, - a2_scale, - dtype, - topk, - block_size=32, - Activation=ActivationType.Gelu, - quant_type=aiter.QuantType.No, - sorted_weights=None, # [max_num_tokens_padded] -): - token_num = hidden_states.shape[0] - D = w2.shape[1] - # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - - out = torch.zeros( - (token_num, D), - dtype=dtype, - device=hidden_states.device, - ) - aiter.ck_moe_stage2_fwd( - hidden_states, - w1, - w2, - sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - out, - topk, - "", - w2_scale, - a2_scale, - block_size, - sorted_weights, - quant_type, - Activation, - ) - return out - - @benchmark() def test_fmoe( dtype, @@ -134,33 +46,30 @@ def test_fmoe( WQDType, use_g1u1=False, doweight_stage1=False, + hidden_pad=0, + intermediate_pad=0, ): if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32: return torch_quant = aiter.get_torch_quant(qType) - torch_act = aiter.get_torch_act(actType) input = torch.randn((token, model_dim), dtype=dtype) if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) + if hidden_pad != 0 and intermediate_pad != 0: + w1[:, :, -hidden_pad:] = 0 + w1[:, -intermediate_pad:, :] = 0 + w1[:, inter_dim - intermediate_pad : inter_dim, :] = 0 + exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) + exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) - + if hidden_pad != 0 and intermediate_pad != 0: + w2[:, :, -intermediate_pad:] = 0 + w2[:, -hidden_pad:, :] = 0 + exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) - # rand topk_weights, topk_ids = fused_topk(input, score, topk, True) - # sequence - # topk_ids_list = [[((i * topk) + j)% E for j in range(topk)] for i in range(token)] - # topk_ids = torch.tensor(topk_ids_list, device=topk_ids.device, dtype=topk_ids.dtype) - - M, _ = topk_ids.shape - - BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) - if qType == aiter.QuantType.per_128x128: - BLOCK_SIZE_M = 64 - sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( - topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M - ) if qType == aiter.QuantType.per_Tensor: w1_qt, w1_scale = aiter.pertoken_quant(w1.view(E, -1), quant_dtype=WQDType) @@ -207,36 +116,42 @@ def weight_per_128x128_quant(weight, quant_dtype): if qType != aiter.QuantType.per_1x32: w1_qt = w1_qt_aiter = w1_qt.view(w1.shape) w2_qt = w2_qt_aiter = w2_qt.view(w2.shape) - else: w1_qt = w1_qt_aiter = w1_qt.view(w1.shape[0], w1.shape[1], w1.shape[2] // 2) w2_qt = w2_qt_aiter = w2_qt.view(w2.shape[0], w2.shape[1], w2.shape[2] // 2) + # Quant-ing a if qType == aiter.QuantType.per_128x128: a1_qt, a1_scale = aiter.pertoken_quant( input.view(token, -1, 128), quant_dtype=AQDType ) a1_qt = a1_qt.view(token, model_dim) a1_scale = a1_scale.squeeze(-1) + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and WQDType == dtypes.fp4x2 + ): # a16w4 + a1_qt = input.to(AQDType) + a1_scale = None else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) - # w1_scale = w1_scale.fill_(1) - # a1_scale = a1_scale.fill_(1) - out1_ref = torch_moe_stage1( - a1_qt, - w1_qt, - w2_qt, - topk_weights, - topk_ids, - dtype=dtype, - activation=actType, - quant_type=qType, - a1_scale=a1_scale, - w1_scale=w1_scale, - doweight=doweight_stage1, - ) + # bias dtype convert + if ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 + exp_bias1_aiter = exp_bias1.to(dtypes.fp32) + exp_bias2_aiter = exp_bias2.to(dtypes.fp32) + else: + exp_bias1_aiter = exp_bias1 = None + exp_bias2_aiter = exp_bias2 = None + # pre-shuffle + w1_scale_aiter = w1_scale + w2_scale_aiter = w2_scale if WQDType == torch.int4: # int4 w quant w1_qt_aiter = rearrange_4bit_elements( convert_int8_to_uint32_int4( @@ -248,67 +163,41 @@ def weight_per_128x128_quant(weight, quant_dtype): shuffle_weight(w2_qt_aiter, (16, 16), use_int4=True) ) ) + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 + w1_qt_aiter = shuffle_weight_a16w4(w1_qt_aiter, 16, True) + w1_scale_aiter = shuffle_scale_a16w4(w1_scale, E, True) + w2_qt_aiter = shuffle_weight_a16w4(w2_qt_aiter, 16, False) + w2_scale_aiter = shuffle_scale_a16w4(w2_scale, E, False) elif WQDType != dtypes.fp4x2: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) - # # ######################## ck stage 1 start ########### - # # a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) - # # out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - # out1_ck, us = run_perftest( - # ck_moe_stage1, - # a1_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w1_scale, - # a1_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type=qType, - # sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # ) - - # checkAllclose( - # out1_ref, - # out1_ck, - # msg=f"[perf] ck_moe_stage1:{us:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - # ######################## stage 1 end ########### - - # if WQDType != torch.int4: - # # asm int4 2 stage not support yet - # if qType == aiter.QuantType.per_Tensor: - # a1_scale = a1_scale.view(1).repeat(token) - # w1_scale = w1_scale.view(E, 1).repeat(1, w1.shape[-2]) - - # out1_asm = torch.empty((token, topk, inter_dim), dtype=dtype) - # _, us = run_perftest( - # asm_stage1, - # a1_qt, - # shuffle_weight(w1_qt, (16, 16)), - # shuffle_weight(w2_qt, (16, 16)), - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # out1_asm, - # topk, - # kernelName="fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2", - # w1_scale=w1_scale, - # a1_scale=a1_scale, - # activation=actType, - # quant_type=qType, - # block_m=BLOCK_SIZE_M, - # ) - # checkAllclose( - # out1_ref, - # out1_asm, - # msg=f"[perf] asm_moe_stage1:{us:>8.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + else: + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + + # # ######################## stage 1 start ########### + out1_ref = torch_moe_stage1( + a1_qt, + w1_qt, + w2_qt, + topk_weights, + topk_ids, + dtype=dtype, + activation=actType, + quant_type=qType, + a1_scale=a1_scale, + w1_scale=w1_scale, + w1_bias=exp_bias1, + doweight=doweight_stage1, + ) # ######################## stage 2 start ########### if qType == aiter.QuantType.per_128x128: @@ -316,6 +205,13 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ref.view(token, -1, 128), quant_dtype=AQDType ) a2_scale = a2_scale.view(token, topk, -1) + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 + a2_qt = out1_ref + a2_scale = None else: a2_qt, a2_scale = torch_quant(out1_ref, quant_dtype=AQDType) a2_qt = a2_qt.view(token, topk, -1) @@ -330,102 +226,43 @@ def weight_per_128x128_quant(weight, quant_dtype): quant_type=qType, w2_scale=w2_scale, a2_scale=a2_scale, + w2_bias=exp_bias2, doweight=not doweight_stage1, ) - # # out_ref = torch_moe( - # # input, - # # w1_qt, - # # w2_qt, - # # topk_weights, - # # topk_ids, - # # fc1_scale=w1_scale, - # # fc2_scale=w2_scale, - # # ) - # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - - # out2_ck, us = run_perftest( - # ck_moe_stage2, - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale, - # a2_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # ) - - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"[perf] ck_moe_stage2:{us:>8.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - # ######################## stage 2 end ########### - # # ######################## fused 2 stage ######### - # out2_ck, us = run_perftest( - # ck_moe_2stages, - # input, - # w1_qt_aiter, - # w2_qt_aiter, - # topk_weights, - # topk_ids, - # quant_type=qType, - # fc1_scale=w1_scale, # [expert(local_expert:EP), inter_dim, 1] - # fc2_scale=w2_scale, # [expert(local_expert:EP), model_dim, 1] - # block_size=BLOCK_SIZE_M, - # activation=actType, - # doweight_stage1=doweight_stage1, - # ) - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"ck_moe_2stages:{us:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - - if dtype == dtypes.bf16: - out2_aiter, us_fuse = run_perftest( - fused_moe, - input, - w1_qt_aiter, - w2_qt_aiter, - topk_weights, - topk_ids, - w1_scale=fp4_utils.e8m0_shuffle( - w1_scale - ), # e8m0_shuffle will do nothing if it's a fp32 - w2_scale=fp4_utils.e8m0_shuffle(w2_scale), - quant_type=qType, - activation=actType, - doweight_stage1=doweight_stage1, - ) - - err = checkAllclose( - out2_ref, - out2_aiter, - msg=f"aiter_all_stages:{us_fuse:>8.2f} us......", - ) - - def calc_diff(x: torch.Tensor, y: torch.Tensor): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - logits_diff = calc_diff(out2_ref, out2_aiter) - assert logits_diff < 1e-3 + # ######################## stage 2 end ########### + out2_ck, us2 = run_perftest( + fused_moe, + input, + w1_qt_aiter, + w2_qt_aiter, + topk_weights, + topk_ids, + w1_scale=w1_scale_aiter, + w2_scale=w2_scale_aiter, + quant_type=qType, + activation=actType, + doweight_stage1=doweight_stage1, + intermediate_pad=intermediate_pad, + hidden_pad=hidden_pad, + bias1=exp_bias1_aiter, + bias2=exp_bias2_aiter, + num_iters=5, + num_warmup=2, + ) + err = checkAllclose( + out2_ref, + out2_ck, + msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) - return {"us": us_fuse, "err": err} + return {"us": us2, "err": err} l_dtype = ["bf16", "fp16"][:1] -l_dim = [(6144, 4096)] +# l_dim = [(6144, 4096)] +l_dim = [(7168, 256)] +# l_dim = [(3072, 3072)] l_tokenNum = [ 1, 3, @@ -446,9 +283,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 + (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] -l_doweight_stage1 = [False, True] +l_doweight_stage1 = [False, True][:1] +l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)][1:2] + parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, @@ -561,29 +401,54 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): if args.doweight_stage1 is not None: l_doweight_stage1 = [args.doweight_stage1] +df = [] for ( dtype, - act_type, (quant_type, aq_dtype, wq_dtype), (model_dim, inter_dim), doweight_stage1, -) in itertools.product(l_dtype, l_act, l_quant, l_dim, l_doweight_stage1): - df = [] - for m in l_tokenNum: - ret = test_fmoe( - dtype, - m, - model_dim, - inter_dim, - args.expert, - args.topk, - act_type, - quant_type, - aq_dtype, - wq_dtype, - use_g1u1=True, - doweight_stage1=doweight_stage1, - ) - df.append(ret) - df = pd.DataFrame(df) - aiter.logger.info(f"summary:\n{df}") +) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1): + if (quant_type, aq_dtype, wq_dtype) == ( + aiter.QuantType.per_1x32, + dtypes.bf16, + dtypes.fp4x2, + ): + for hidden_pad, intermediate_pad in l_hidden_intermediate_pad: + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + aiter.ActivationType.Swiglu, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + ) + df.append(ret) + else: + for act_type in l_act: + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + act_type, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + ) + df.append(ret) +df = pd.DataFrame(df) +aiter.logger.info(f"summary:\n{df}")