diff --git a/aiter/configs/model_configs/dsv3_fp4_tuned_fmoe.csv b/aiter/configs/model_configs/dsv3_fp4_tuned_fmoe.csv index e586713de6..f6aed8a961 100644 --- a/aiter/configs/model_configs/dsv3_fp4_tuned_fmoe.csv +++ b/aiter/configs/model_configs/dsv3_fp4_tuned_fmoe.csv @@ -1,7 +1,7 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,tflops,bw,_tag -256,1,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,13.307,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb14_fq,20.3%,7.6176,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.3%,20.9246,0,4.74,67614.8, -256,2,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.3539,flydsl_moe1_afp4_wfp4_bf16_t16x128x256_w3_kb7_bnt0_go_fq,18.7%,9.3549,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.3%,25.7088,0,7.71,55033.07, -256,4,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,21.7862,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb2_bnt0_go_fq,16.6%,11.8142,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.3%,33.6004,0,11.8,42108.94, +256,1,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,13.307,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb14_fp4,20.3%,7.6176,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.3%,20.9246,0,4.74,67614.8, +256,2,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.3539,flydsl_moe1_afp4_wfp4_bf16_t16x128x256_w3_kb7_bnt0_go_fp4,18.7%,9.3549,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.3%,25.7088,0,7.71,55033.07, +256,4,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,21.7862,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb2_bnt0_go_fp4,16.6%,11.8142,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.3%,33.6004,0,11.8,42108.94, 256,8,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,31.3946,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,18.8511,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.2%,50.2457,0,15.78,28160.88, 256,16,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,52.7618,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,30.7153,flydsl_moe2_afp4_wfp4_bf16_t16x256x256_atomic_sbm32,1.3%,83.4771,0,18.99,16952.38, 256,32,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,86.6761,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,49.9872,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.2%,136.6633,0,23.2,10357.42, @@ -15,9 +15,9 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,8192,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,273.7589,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_bnt0,0.0%,515.3227,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist,0.0%,789.0816,0,1028.73,2016.21, 256,16384,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,425.3491,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,1027.5233,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,1452.8724,0,1117.44,1216.29, 256,32768,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,669.5871999999999,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,2017.8465,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,2687.4337,0,1208.21,788.65, -256,1,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.6555,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fq,23.1%,8.2174,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.6%,24.8729,0,7.97,113762.52, -256,2,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,22.244,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_go_fq,20.6%,14.059,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.7%,36.303,0,10.92,77944.67, -256,4,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.5005,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3_fq,19.5%,19.299,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,47.7995,0,16.58,59198.7, +256,1,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.6555,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fp4,23.1%,8.2174,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.6%,24.8729,0,7.97,113762.52, +256,2,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,22.244,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_go_fp4,20.6%,14.059,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.7%,36.303,0,10.92,77944.67, +256,4,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.5005,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3_fp4,19.5%,19.299,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,47.7995,0,16.58,59198.7, 256,8,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,54.3584,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3,0.0%,30.4539,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,3.0%,84.8123,0,18.69,33364.91, 256,16,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,96.3987,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,51.0459,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,147.4446,0,21.51,19193.15, 256,32,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,163.055,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3,0.0%,89.1386,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist,2.9%,252.1936,0,25.15,11222.61, diff --git a/aiter/configs/model_configs/gptoss_fp8fp4_tuned_fmoe.csv b/aiter/configs/model_configs/gptoss_fp8fp4_tuned_fmoe.csv new file mode 100644 index 0000000000..86a5dea5fa --- /dev/null +++ b/aiter/configs/model_configs/gptoss_fp8fp4_tuned_fmoe.csv @@ -0,0 +1,15 @@ +cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,tflops,bw,_tag +256,512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,214.3544,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w2_gui_fp8,0.0%,111.631,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2_persist,0.0%,325.9854,0,355.73,11131.16, +256,1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,235.2077,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_gui_fp8,0.0%,125.5088,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_atomic_bnt2,0.0%,360.7165,0,642.97,10072.5, +256,2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,312.5584,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_w2_bnt0_gui_fp8,0.0%,172.1029,flydsl_moe2_afp8_wfp4_bf16_t64x128x256_atomic_persist,0.0%,484.6613,0,957.07,7516.08, +256,4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,442.3352,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui_fp8,0.0%,256.1523,flydsl_moe2_afp8_wfp4_bf16_t64x128x256_atomic_persist_sbm128,0.0%,698.4875,0,1328.17,5242.22, +256,8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,714.6281,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_bnt0_gui_fp8,0.0%,413.5452,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_atomic_persist_sbm128,0.0%,1128.1733,0,1644.63,3279.08, +256,16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1356.5778,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui,0.0%,731.3886,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_atomic_sbm128,0.0%,2087.9664,0,1777.26,1807.92, +256,32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,2474.7814,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_w2_bnt0_gui_fp8,0.0%,1348.5732,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_atomic_xcd4_persist,0.0%,3823.3546,0,1941.15,1026.81, +256,512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,527.0715,cktile_a8w4_bm32,0.0,117.3402,cktile_a8w4_bm32,0.0,644.4117,0,0.0,0.0,flydsl_fallback +256,1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,286.4048,cktile_a8w4_bm32,0.0,142.6674,cktile_a8w4_bm32,0.0,429.0722,0,0.0,0.0,flydsl_fallback +256,2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,446.6267,cktile_a8w4_bm32,0.0,181.4069,cktile_a8w4_bm32,0.0,628.0336,0,0.0,0.0,flydsl_fallback +256,4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,786.0193,cktile_a8w4_bm32,0.0,275.9191,cktile_a8w4_bm32,0.0,1061.9384,0,0.0,0.0,flydsl_fallback +256,8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,1478.4233,cktile_a8w4_bm32,0.0,480.9397,cktile_a8w4_bm32,0.0,1959.363,0,0.0,0.0,flydsl_fallback +256,16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,2752.7649,cktile_a8w4_bm32,0.0,908.23,cktile_a8w4_bm32,0.0,3660.9949,0,0.0,0.0,flydsl_fallback +256,32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,5411.816400000001,cktile_a8w4_bm32,0.0,1750.1288,cktile_a8w4_bm32,0.0,7161.9452,0,0.0,0.0,flydsl_fallback diff --git a/aiter/configs/model_configs/gptoss_fp8fp4_untuned_fmoe.csv b/aiter/configs/model_configs/gptoss_fp8fp4_untuned_fmoe.csv new file mode 100644 index 0000000000..3bff7e7710 --- /dev/null +++ b/aiter/configs/model_configs/gptoss_fp8fp4_untuned_fmoe.csv @@ -0,0 +1,8 @@ +token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1 +512,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +1024,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +2048,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +4096,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +8192,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +16384,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +32768,3072,3072,128,4,ActivationType.Swiglu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 \ No newline at end of file diff --git a/aiter/configs/model_configs/kimik2_fp4_tuned_fmoe.csv b/aiter/configs/model_configs/kimik2_fp4_tuned_fmoe.csv index 8fea2ba24b..17ed61cbcb 100644 --- a/aiter/configs/model_configs/kimik2_fp4_tuned_fmoe.csv +++ b/aiter/configs/model_configs/kimik2_fp4_tuned_fmoe.csv @@ -1,129 +1,129 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,tflops,bw,_tag -256,1,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,13.1151,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb7_go_fq,17.6%,7.5545,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.5%,20.6696,0,4.26,102273.42, -256,2,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,15.2315,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb7_go_fq,16.8%,9.1849,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.5%,24.4164,0,7.21,86580.01, -256,4,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,20.0099,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_go_fq,15.8%,12.7265,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,1.3%,32.7364,0,10.76,64576.9, -256,8,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,29.5989,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,18.3012,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.2%,47.9001,0,14.71,44135.63, -256,16,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,45.079,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3,0.0%,27.7085,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.2%,72.7875,0,19.36,29047.2, -256,32,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,80.3136,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4_fq,17.6%,47.0468,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.2%,127.3604,0,22.13,16603.41, -256,64,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,119.6923,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4,0.0%,70.9479,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.1%,190.6402,0,29.57,11095.8, -256,128,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,120.0591,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,73.0398,flydsl_moe2_afp4_wfp4_bf16_t16x128x256_atomic_persist_sbm32,1.2%,193.0989,0,58.39,10961.65, -256,256,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,117.4334,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fq,17.6%,76.8978,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.2%,194.3312,0,116.03,10906.3, -256,512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,123.1326,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,83.0149,flydsl_moe2_afp4_wfp4_bf16_t16x128x256_atomic_sbm32,1.2%,206.1475,0,218.76,10307.86, -256,1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,120.0003,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_fq,17.3%,105.9035,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_persist,1.2%,225.9038,0,399.26,9455.13, -256,2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,134.7677,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w3,0.0%,156.8298,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce,0.0%,291.5975,0,618.62,7400.51, -256,4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,156.8228,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w4,0.0%,265.7605,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,422.5833,0,853.74,5210.83, -256,8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,264.6659,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,456.9264,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,721.5923,0,999.95,3173.66, -256,16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,412.2654,flydsl_moe1_afp4_wfp4_bf16_t128x64x256_w4_bnt0,0.0%,916.801,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,1329.0664,0,1085.81,1855.63, -256,32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,660.1809,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,1742.326,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,2402.5069,0,1201.34,1173.18, -256,1,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.0264,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb7_go_fq,12.7%,10.3185,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.6%,26.3449,0,6.69,160481.91, -256,2,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,19.9341,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_go_fq,15.4%,12.9207,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,32.8548,0,10.72,128684.44, -256,4,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.773,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w4,0.0%,17.2346,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.8%,46.0076,0,15.32,91896.65, -256,8,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,45.3057,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w2,0.0%,27.3685,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.8%,72.6742,0,19.39,58177.87, -256,16,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,83.6614,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,47.1953,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.7%,130.8567,0,21.54,32311.7, -256,32,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,158.164,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w2,0.0%,83.8976,flydsl_moe2_afp4_wfp4_bf16_t16x128x256_atomic_persist_sbm32,2.7%,242.0616,0,23.29,17468.89, -256,64,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,231.9115,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w2,0.0%,128.4022,flydsl_moe2_afp4_wfp4_bf16_t16x128x256_atomic_persist_sbm32,2.7%,360.3137,0,31.29,11737.65, -256,128,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,231.5877,flydsl_moe1_afp4_wfp4_bf16_t32x128x256,0.0%,132.2453,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist,2.7%,363.833,0,61.98,11627.89, -256,256,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,232.0612,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fq,17.5%,137.9798,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist_sbm64,2.7%,370.041,0,121.87,11440.26, -256,512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,231.1873,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fq,17.3%,144.8942,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist,2.7%,376.0815,0,239.83,11271.14, -256,1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,233.681,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fq,17.3%,158.5219,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,2.7%,392.2029,0,459.94,10835.92, -256,2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,239.6427,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fq,17.3%,208.0474,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist,2.7%,447.6901,0,805.86,9542.09, -256,4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,299.92010000000005,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w3,0.0%,370.0656,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist_sbm128,0.3%,669.9857,0,1076.97,6441.84, -256,8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,496.2138,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_bnt0,0.0%,612.0706,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce,0.3%,1108.2844,0,1302.11,3973.73, -256,16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,760.59,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2_bnt0,0.0%,1186.6336,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist_sbm128,0.3%,1947.2236,0,1482.22,2352.16, -256,32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1384.0164,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w4_bnt0,0.0%,2304.5457,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist_sbm128,0.3%,3688.5621,0,1564.96,1337.24, -256,1,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.8796,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb7_go_fq,17.2%,10.604,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,27.4836,0,7.21,154233.43, -256,2,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,23.2686,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb2_go_fq,18.4%,13.5153,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.4%,36.7839,0,10.78,115238.23, -256,4,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,27.0697,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w4_fq,17.2%,18.7899,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.7%,45.8596,0,17.29,92433.31, -256,8,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,54.3209,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,30.2102,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.8%,84.5311,0,18.76,50147.7, -256,16,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,94.5343,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,52.56,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,147.0943,0,21.56,28819.69, -256,32,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,177.18269999999998,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3,0.0%,97.6084,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,2.9%,274.7911,0,23.08,15428.29, -256,64,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,243.3096,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,127.571,flydsl_moe2_afp4_wfp4_bf16_t16x256x256_atomic_sbm32,2.9%,370.8806,0,34.2,11432.91, -256,128,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,241.9776,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,129.9156,flydsl_moe2_afp4_wfp4_bf16_t16x256x256_atomic_sbm32,2.8%,371.8932,0,68.21,11405.48, -256,256,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,243.8308,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_fq,17.2%,136.2753,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist_sbm64,2.9%,380.1061,0,133.47,11166.29, -256,512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,241.7796,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4_fq,17.2%,144.5642,flydsl_moe2_afp4_wfp4_bf16_t16x256x256_atomic_persist_sbm32,2.8%,386.3438,0,262.64,11000.25, -256,1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,247.4374,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_fq,17.3%,162.5128,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist_sbm64,2.9%,409.9502,0,495.03,10393.67, -256,2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,258.4258,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fq,17.2%,227.4551,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist,0.2%,485.8809,0,835.34,8814.73, -256,4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,330.43420000000003,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2,0.0%,390.7126,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_sbm128,0.2%,721.1468,0,1125.64,6000.09, -256,8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,512.8136,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_bnt0,0.0%,657.8224,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist,0.2%,1170.636,0,1386.85,3771.48, -256,16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,788.8488,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w3_bnt0,0.0%,1289.7255,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist_sbm128,0.2%,2078.5743,0,1562.13,2208.82, -256,32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1421.132,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_bnt0,0.0%,2576.7927,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist_sbm128,0.2%,3997.9247,0,1624.34,1236.52, -256,1,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,13.4816,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb14_bnt0_go_fq,19.9%,6.9284,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.3%,20.41,0,4.85,103843.99, -256,2,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,17.1983,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb7_fq,18.1%,8.8172,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.2%,26.0155,0,7.62,81469.79, -256,4,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,23.1761,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_go_fq,17.5%,13.282,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,1.3%,36.4581,0,10.87,58135.78, -256,8,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,31.2402,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,19.5754,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.2%,50.8156,0,15.6,41711.72, -256,16,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,54.3925,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,30.4551,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,1.3%,84.8476,0,18.69,24983.36, -256,32,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,93.0795,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3_fq,16.5%,55.1143,flydsl_moe2_afp4_wfp4_bf16_t16x128x256_atomic_persist_sbm32,1.2%,148.1938,0,21.4,14306.42, -256,64,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,128.7229,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,71.8716,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_persist,1.3%,200.5945,0,31.61,10572.63, -256,128,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,128.8881,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fq,16.6%,73.7141,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic,1.2%,202.6022,0,62.6,10474.65, -256,256,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,129.1083,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fq,17.0%,77.9378,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_persist,1.2%,207.0461,0,122.52,10263.12, -256,512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,134.4496,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,85.7887,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_persist,1.2%,220.2383,0,230.36,9673.36, -256,1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,138.9433,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,117.7637,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_persist,1.2%,256.707,0,395.27,8342.02, -256,2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,147.6131,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3,0.0%,170.222,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce,0.0%,317.8351,0,638.5,6806.91, -256,4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,188.4282,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w4,0.0%,287.5067,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,475.9349,0,852.79,4638.27, -256,8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,289.1852,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,485.1722,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,774.3574,0,1048.29,2964.52, -256,16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,455.4862,flydsl_moe1_afp4_wfp4_bf16_t128x64x256_w4_bnt0,0.0%,997.0494,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,1452.5356,0,1117.7,1701.68, -256,32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,704.8525,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,1952.5578,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,2657.4103,0,1221.86,1062.72, -256,1,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,25.1935,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,6.9284,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.013,32.1219,0,0.0,0.0,flydsl_fallback -256,2,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,25.8511,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,8.8172,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0119,34.6683,0,0.0,0.0,flydsl_fallback -256,4,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.3615,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,15.584,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0128,43.9455,0,0.0,0.0,flydsl_fallback -256,8,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,32.4019,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.5754,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,51.9773,0,0.0,0.0,flydsl_fallback -256,16,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,56.6188,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,31.3397,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0129,87.9585,0,0.0,0.0,flydsl_fallback -256,32,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,106.3936,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,63.4416,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,169.8352,0,0.0,0.0,flydsl_fallback -256,64,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,140.4823,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,72.472,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,212.9543,0,0.0,0.0,flydsl_fallback -256,128,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,140.419,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,75.8299,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,216.2489,0,0.0,0.0,flydsl_fallback -256,256,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,141.8739,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,79.4328,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,221.3067,0,0.0,0.0,flydsl_fallback -256,512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,143.0963,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,92.1974,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,235.2937,0,0.0,0.0,flydsl_fallback -256,1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,146.784,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,133.1199,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,279.9039,0,0.0,0.0,flydsl_fallback -256,2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,167.7902,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,237.8075,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,405.5977,0,0.0,0.0,flydsl_fallback -256,4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,195.5058,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,454.2251,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,649.7309,0,0.0,0.0,flydsl_fallback -256,8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,289.1852,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,857.905,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,1147.0902,0,0.0,0.0,flydsl_fallback -256,16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,464.4265,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1701.4992,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,2165.9257,0,0.0,0.0,flydsl_fallback -256,32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,704.8525,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,3454.8083,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,4159.6608,0,0.0,0.0,flydsl_fallback -256,1,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,25.2469,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,8.6968,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0248,33.9437,0,0.0,0.0,flydsl_fallback -256,2,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,27.7751,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,13.0489,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0282,40.824,0,0.0,0.0,flydsl_fallback -256,4,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,32.8017,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.7088,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0271,52.5105,0,0.0,0.0,flydsl_fallback -256,8,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,55.880300000000005,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,32.051,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0267,87.9313,0,0.0,0.0,flydsl_fallback -256,16,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,105.0341,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,53.7335,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0279,158.7676,0,0.0,0.0,flydsl_fallback -256,32,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,198.7704,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,96.5701,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0291,295.3405,0,0.0,0.0,flydsl_fallback -256,64,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,260.5895,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,126.2447,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,386.8342,0,0.0,0.0,flydsl_fallback -256,128,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,260.1268,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,130.1134,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0289,390.2402,0,0.0,0.0,flydsl_fallback -256,256,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,261.8685,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,134.7762,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0287,396.6447,0,0.0,0.0,flydsl_fallback -256,512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,269.4169,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,162.2892,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0282,431.7061,0,0.0,0.0,flydsl_fallback -256,1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,274.1024,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,277.6658,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,551.7682,0,0.0,0.0,flydsl_fallback -256,2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,302.5024,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,351.6368,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,654.1392,0,0.0,0.0,flydsl_fallback -256,4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,466.4292,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,611.4182,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,1077.8474,0,0.0,0.0,flydsl_fallback -256,8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,534.4255,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1009.7773,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,1544.2028,0,0.0,0.0,flydsl_fallback -256,16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,865.0261,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1973.5537,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,2838.5798,0,0.0,0.0,flydsl_fallback -256,32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1588.1829,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,3805.4252,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,5393.6081,0,0.0,0.0,flydsl_fallback -256,1,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.083,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,7.5545,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0152,31.6375,0,0.0,0.0,flydsl_fallback -256,2,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.5225,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,9.1849,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0146,33.7074,0,0.0,0.0,flydsl_fallback -256,4,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,26.7832,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,14.3059,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0137,41.0891,0,0.0,0.0,flydsl_fallback -256,8,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,30.4611,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.1838,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,49.6449,0,0.0,0.0,flydsl_fallback -256,16,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,50.0482,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,29.6543,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,79.7025,0,0.0,0.0,flydsl_fallback -256,32,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,91.6176,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,49.5208,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,141.1384,0,0.0,0.0,flydsl_fallback -256,64,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,132.9996,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,71.7614,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0114,204.761,0,0.0,0.0,flydsl_fallback -256,128,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,133.5936,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,74.0606,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,207.6542,0,0.0,0.0,flydsl_fallback -256,256,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,134.2042,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,78.5686,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,212.7728,0,0.0,0.0,flydsl_fallback -256,512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,136.536,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,88.8643,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,225.4003,0,0.0,0.0,flydsl_fallback -256,1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,139.5184,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,117.4777,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,256.9961,0,0.0,0.0,flydsl_fallback -256,2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,148.86509999999998,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,217.4645,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,366.3296,0,0.0,0.0,flydsl_fallback -256,4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,171.4984,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,401.5359,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0116,573.0343,0,0.0,0.0,flydsl_fallback -256,8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,264.6659,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,758.2827,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,1022.9486,0,0.0,0.0,flydsl_fallback -256,16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,441.1082,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1455.2578,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0116,1896.366,0,0.0,0.0,flydsl_fallback -256,32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,660.1809,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,2869.8423,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0116,3530.0232,0,0.0,0.0,flydsl_fallback -256,1,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.5993,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,7.3376,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0244,31.9369,0,0.0,0.0,flydsl_fallback -256,2,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,26.5764,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,14.5195,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0298,41.0959,0,0.0,0.0,flydsl_fallback -256,4,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,30.4475,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.0794,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0272,49.5269,0,0.0,0.0,flydsl_fallback -256,8,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,49.6576,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,29.8515,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0269,79.5091,0,0.0,0.0,flydsl_fallback -256,16,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,91.2137,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,48.9993,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.028,140.213,0,0.0,0.0,flydsl_fallback -256,32,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,174.01690000000002,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,86.3919,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0276,260.4088,0,0.0,0.0,flydsl_fallback -256,64,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,255.9308,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,125.1029,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,381.0337,0,0.0,0.0,flydsl_fallback -256,128,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,255.3112,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,127.4089,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0272,382.7201,0,0.0,0.0,flydsl_fallback -256,256,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,258.5313,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,132.2283,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0274,390.7596,0,0.0,0.0,flydsl_fallback -256,512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,260.838,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,146.7813,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,407.6193,0,0.0,0.0,flydsl_fallback -256,1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,267.26160000000004,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,229.4312,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0274,496.6928,0,0.0,0.0,flydsl_fallback -256,2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,286.02410000000003,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,342.0605,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,628.0846,0,0.0,0.0,flydsl_fallback -256,4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,458.2104,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,567.252,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,1025.4624,0,0.0,0.0,flydsl_fallback -256,8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,514.8109,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,938.2635,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,1453.0744,0,0.0,0.0,flydsl_fallback -256,16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,824.8789999999999,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1846.1768,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,2671.0558,0,0.0,0.0,flydsl_fallback -256,32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1448.6479,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,3504.3122,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,4952.9601,0,0.0,0.0,flydsl_fallback +256,1,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,13.2452,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb7_go_fp4,16.9%,6.6844,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.6%,19.9296,0,4.42,106070.91, +256,2,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,15.8944,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fp4,17.2%,8.4755,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.3%,24.3699,0,7.23,86745.22, +256,4,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,21.3039,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_fp4,16.1%,12.4074,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,1.3%,33.7113,0,10.45,62709.4, +256,8,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,29.8852,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4,0.0%,17.8471,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,47.7323,0,14.76,44290.79, +256,16,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,44.7944,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4,0.0%,27.4087,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,72.2031,0,19.52,29282.31, +256,32,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,81.51140000000001,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,44.797,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,126.3084,0,22.32,16741.7, +256,64,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,118.8141,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2,0.0%,67.0329,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,1.1%,185.847,0,30.33,11381.97, +256,128,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,118.5136,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4,0.0%,69.4259,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,187.9395,0,59.99,11262.57, +256,256,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,115.8638,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fp4,17.6%,74.2967,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,1.2%,190.1605,0,118.58,11145.5, +256,512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,117.5612,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_fp4,17.0%,80.6277,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,198.1889,0,227.55,10721.79, +256,1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,119.968,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,17.3%,102.033,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2_sbm64,1.1%,222.001,0,406.28,9621.35, +256,2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,132.5067,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4,0.0%,157.7576,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_bnt2,0.0%,290.2643,0,621.46,7434.5, +256,4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,156.74679999999998,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2,0.0%,267.2572,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,424.004,0,850.88,5193.37, +256,8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,250.0777,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w3_bnt0_xcd4,0.0%,463.8763,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,713.954,0,1010.65,3207.62, +256,16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,366.0242,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0_xcd4_fp4,17.3%,926.6298,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce,0.0%,1292.654,0,1116.39,1907.9, +256,32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,631.8818,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0_xcd4_fp4,17.3%,1751.4067,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist,0.0%,2383.2885,0,1211.02,1182.64, +256,1,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.354,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fp4,12.7%,7.2729,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,2.6%,23.6269,0,7.46,178943.49, +256,2,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,21.5812,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_fp4,15.4%,12.8273,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,3.0%,34.4085,0,10.24,122873.75, +256,4,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.0247,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,17.2081,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,2.9%,45.2328,0,15.58,93470.77, +256,8,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,44.514500000000005,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_xcd4,0.0%,27.0211,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,2.8%,71.5356,0,19.7,59103.87, +256,16,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,81.495,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_xcd4,0.0%,44.4102,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2_xcd4,2.7%,125.9052,0,22.39,33582.43, +256,32,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,154.3846,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,80.4195,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_xcd4_persist,2.7%,234.8041,0,24.01,18008.83, +256,64,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,228.1004,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,17.2%,118.3423,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,2.8%,346.4427,0,32.54,12207.6, +256,128,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,228.1496,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,121.3876,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,2.7%,349.5372,0,64.51,12103.46, +256,256,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,228.5579,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,17.5%,124.7,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,2.7%,353.2579,0,127.66,11983.78, +256,512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,232.2845,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,17.3%,136.0064,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist_sbm64,2.7%,368.2909,0,244.9,11509.57, +256,1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,233.5376,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4_fp4,17.3%,146.6693,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,2.7%,380.2069,0,474.45,11177.8, +256,2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,238.474,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,17.3%,206.1031,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_bnt2_persist,2.7%,444.5771,0,811.51,9608.9, +256,4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,292.4502,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w4,0.0%,377.921,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4_persist_sbm128,0.3%,670.3712,0,1076.35,6438.13, +256,8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,459.2863,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0_fp4,17.3%,626.7223,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist,0.3%,1086.0086,0,1328.82,4055.23, +256,16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,691.0762,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,1171.9833,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4_sbm128,0.3%,1863.0595,0,1549.18,2458.42, +256,32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1208.457,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0_fp4,17.3%,2227.9499,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4,0.3%,3436.4069,0,1679.79,1435.37, +256,1,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,17.8501,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb7_fp4,16.7%,7.6621,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,2.4%,25.5122,0,7.77,166151.49, +256,2,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,22.8551,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb7_fp4,18.9%,13.1521,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,2.7%,36.0072,0,11.01,117724.0, +256,4,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,29.4262,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3,0.0%,19.3881,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,2.8%,48.8143,0,16.24,86838.38, +256,8,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,50.5137,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4_fp4,18.7%,29.363,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,2.7%,79.8767,0,19.85,53069.8, +256,16,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,95.7203,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,49.2937,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,2.8%,145.014,0,21.87,29233.13, +256,32,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,171.3749,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_xcd4,0.0%,89.5678,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,2.8%,260.9427,0,24.3,16247.08, +256,64,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,236.2525,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_xcd4,0.0%,120.7332,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_sbm64,2.8%,356.9857,0,35.53,11877.91, +256,128,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,234.1601,flydsl_moe1_afp4_wfp4_bf16_t64x64x256_w4_xcd4_fp4,17.1%,122.0701,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_sbm64,2.9%,356.2302,0,71.21,11906.97, +256,256,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,236.6952,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_xcd4_fp4,17.2%,127.4914,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist_sbm64,2.9%,364.1866,0,139.31,11654.39, +256,512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,242.1832,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,17.2%,134.7446,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,2.8%,376.9278,0,269.2,11275.05, +256,1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,246.3426,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,17.3%,153.0524,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,2.9%,399.395,0,508.11,10668.36, +256,2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,251.0842,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,17.2%,233.9762,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_atomic_bnt2_persist,2.8%,485.0604,0,836.75,8829.64, +256,4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,316.9738,flydsl_moe1_afp4_wfp4_bf16_t128x128x256,0.0%,400.1827,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.2%,717.1565,0,1131.9,6033.48, +256,8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,473.8717,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0_fp4,17.2%,670.9802,flydsl_moe2_afp4_wfp4_bf16_t64x128x256_reduce_persist,0.2%,1144.8519,0,1418.09,3856.42, +256,16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,731.1314,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_bnt0_fp4,17.3%,1313.4823,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4_persist,0.2%,2044.6137,0,1588.07,2245.5, +256,32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1306.3543,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w2_bnt0_fp4,17.3%,2514.8871,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_xcd4_persist,0.2%,3821.2414,0,1699.45,1293.69, +256,1,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,13.3809,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb14_go_fp4,20.1%,6.8161,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.2%,20.197,0,4.91,104939.14, +256,2,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.4835,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fp4,17.6%,8.4904,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,1.0%,24.9739,0,7.94,84867.69, +256,4,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,22.3365,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb2_go_fp4,17.3%,13.1164,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,1.2%,35.4529,0,11.18,59784.12, +256,8,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,30.8842,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w4,0.0%,19.279,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,1.3%,50.1632,0,15.8,42254.21, +256,16,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,54.5127,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3,0.0%,29.3089,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.3%,83.8216,0,18.91,25289.17, +256,32,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,93.3433,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,53.0394,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.3%,146.3827,0,21.66,14483.42, +256,64,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,129.68970000000002,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4_xcd4,0.0%,68.2803,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,1.2%,197.97,0,32.03,10712.79, +256,128,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,126.7838,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w4_xcd4_fp4,16.6%,70.9723,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2_persist,1.2%,197.7561,0,64.14,10731.33, +256,256,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,127.3984,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_xcd4_fp4,17.0%,73.9814,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,201.3798,0,125.97,10551.9, +256,512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,130.7355,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,17.1%,82.6416,flydsl_moe2_afp4_wfp4_bf16_t32x256x256_atomic_bnt2,1.2%,213.3771,0,237.77,9984.41, +256,1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,133.9164,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w2_fp4,17.4%,115.6845,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic_bnt2,1.2%,249.6009,0,406.52,8579.51, +256,2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,139.5789,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w3_fp4,17.2%,169.6701,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_bnt2_persist,0.0%,309.249,0,656.23,6995.9, +256,4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,177.8189,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_w2,0.0%,282.0472,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,459.8661,0,882.59,4800.34, +256,8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,284.6439,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_xcd4,0.0%,486.7326,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,771.3765,0,1052.34,2975.97, +256,16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,440.0206,flydsl_moe1_afp4_wfp4_bf16_t128x128x256_bnt0_xcd4,0.0%,979.1867,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,1419.2073,0,1143.95,1741.65, +256,32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,700.9508,flydsl_moe1_afp4_wfp4_bf16_t64x128x256_w4_bnt0_xcd4_fp4,17.3%,1957.461,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce,0.0%,2658.4118,0,1221.4,1062.32, +256,1,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.6248,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,6.6844,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0156,31.3092,0,0.0,0.0,flydsl_fallback +256,2,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.6277,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,8.4755,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0133,33.1032,0,0.0,0.0,flydsl_fallback +256,4,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,26.1829,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,14.3766,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0138,40.5595,0,0.0,0.0,flydsl_fallback +256,8,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,30.2074,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,18.9994,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,49.2068,0,0.0,0.0,flydsl_fallback +256,16,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,50.1851,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,28.8427,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0119,79.0278,0,0.0,0.0,flydsl_fallback +256,32,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,91.9274,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,56.4536,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0119,148.381,0,0.0,0.0,flydsl_fallback +256,64,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,133.52890000000002,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,72.2503,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0114,205.7792,0,0.0,0.0,flydsl_fallback +256,128,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,134.37,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,75.4038,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0113,209.7738,0,0.0,0.0,flydsl_fallback +256,256,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,135.7072,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,80.5166,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,216.2238,0,0.0,0.0,flydsl_fallback +256,512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,136.8202,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,88.9883,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,225.8085,0,0.0,0.0,flydsl_fallback +256,1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,141.0647,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,117.1939,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,258.2586,0,0.0,0.0,flydsl_fallback +256,2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,150.5248,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,214.4916,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,365.0164,0,0.0,0.0,flydsl_fallback +256,4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,176.7558,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,403.0595,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,579.8153,0,0.0,0.0,flydsl_fallback +256,8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,267.11560000000003,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,759.2218,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0117,1026.3374,0,0.0,0.0,flydsl_fallback +256,16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,437.7072,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1455.2831,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0116,1892.9903,0,0.0,0.0,flydsl_fallback +256,32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,663.3669,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,2875.3131,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0116,3538.68,0,0.0,0.0,flydsl_fallback +256,1,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.539,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,7.2729,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0259,31.8119,0,0.0,0.0,flydsl_fallback +256,2,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,26.6509,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,14.035,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0296,40.6859,0,0.0,0.0,flydsl_fallback +256,4,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,31.2822,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.072,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0281,50.3542,0,0.0,0.0,flydsl_fallback +256,8,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,49.7396,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,29.5203,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0279,79.2599,0,0.0,0.0,flydsl_fallback +256,16,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,92.3538,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,48.1233,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0279,140.4771,0,0.0,0.0,flydsl_fallback +256,32,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,174.5219,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,85.8354,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0276,260.3573,0,0.0,0.0,flydsl_fallback +256,64,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,257.9019,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,125.3541,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,383.256,0,0.0,0.0,flydsl_fallback +256,128,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,256.3908,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,127.7442,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,384.135,0,0.0,0.0,flydsl_fallback +256,256,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,260.6848,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,133.543,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,394.2278,0,0.0,0.0,flydsl_fallback +256,512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,263.3216,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,149.5248,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,412.8464,0,0.0,0.0,flydsl_fallback +256,1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,273.1676,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,224.9774,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0274,498.145,0,0.0,0.0,flydsl_fallback +256,2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,289.69550000000004,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,339.3983,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0272,629.0938,0,0.0,0.0,flydsl_fallback +256,4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,452.9873,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,569.9896,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,1022.9769,0,0.0,0.0,flydsl_fallback +256,8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,516.1552,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,938.8917,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,1455.0469,0,0.0,0.0,flydsl_fallback +256,16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,839.7145,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1859.6769,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,2699.3914,0,0.0,0.0,flydsl_fallback +256,32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1469.5306,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,3517.7145,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0273,4987.2451,0,0.0,0.0,flydsl_fallback +256,1,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,25.0451,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,7.6621,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0241,32.7072,0,0.0,0.0,flydsl_fallback +256,2,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,27.1551,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,14.1879,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,41.343,0,0.0,0.0,flydsl_fallback +256,4,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,32.0801,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.8641,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0281,51.9442,0,0.0,0.0,flydsl_fallback +256,8,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,55.6951,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,31.6528,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0265,87.3479,0,0.0,0.0,flydsl_fallback +256,16,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,105.1921,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,53.563,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.028,158.7551,0,0.0,0.0,flydsl_fallback +256,32,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,199.5216,moe_ck2stages_gemm1_256x32x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,97.0299,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0288,296.5515,0,0.0,0.0,flydsl_fallback +256,64,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,260.8307,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,125.6864,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,386.5171,0,0.0,0.0,flydsl_fallback +256,128,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,266.3821,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,128.3034,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,394.6855,0,0.0,0.0,flydsl_fallback +256,256,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,265.6527,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,135.5647,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0286,401.2174,0,0.0,0.0,flydsl_fallback +256,512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,269.6034,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,162.5341,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0283,432.1375,0,0.0,0.0,flydsl_fallback +256,1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,280.10200000000003,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,275.7737,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,555.8757,0,0.0,0.0,flydsl_fallback +256,2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,311.42830000000004,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,349.9696,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,661.3979,0,0.0,0.0,flydsl_fallback +256,4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,460.8264,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,617.8159,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,1078.6423,0,0.0,0.0,flydsl_fallback +256,8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,539.202,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,995.9833,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0284,1535.1853,0,0.0,0.0,flydsl_fallback +256,16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,876.0268,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1986.4763,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,2862.5031,0,0.0,0.0,flydsl_fallback +256,32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1611.5569,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,3894.0732,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0285,5505.6301,0,0.0,0.0,flydsl_fallback +256,1,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,23.9613,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,6.8161,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0124,30.7774,0,0.0,0.0,flydsl_fallback +256,2,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,24.9598,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,8.4904,moe_ck2stages_gemm2_256x32x128x128_1x4_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0102,33.4502,0,0.0,0.0,flydsl_fallback +256,4,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.0817,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,14.2802,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0133,42.3619,0,0.0,0.0,flydsl_fallback +256,8,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,31.893,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,19.8686,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0125,51.7616,0,0.0,0.0,flydsl_fallback +256,16,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,55.979400000000005,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,31.3644,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,87.3438,0,0.0,0.0,flydsl_fallback +256,32,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,105.7492,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,54.6225,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.013,160.3717,0,0.0,0.0,flydsl_fallback +256,64,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,138.129,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,70.2826,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,208.4116,0,0.0,0.0,flydsl_fallback +256,128,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,139.39800000000002,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,76.9762,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.012,216.3742,0,0.0,0.0,flydsl_fallback +256,256,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,140.2029,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,78.7722,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0124,218.9751,0,0.0,0.0,flydsl_fallback +256,512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,142.1878,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,92.4496,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,234.6374,0,0.0,0.0,flydsl_fallback +256,1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,148.3272,moe_ck2stages_gemm1_64x32x32x128_1x1_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,133.2883,moe_ck2stages_gemm2_64x32x32x128_1x1_MulABScaleExpertWeightShuffled_v1_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0122,281.6155,0,0.0,0.0,flydsl_fallback +256,2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,163.71689999999998,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,237.3456,moe_ck2stages_gemm2_64x64x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,401.0625,0,0.0,0.0,flydsl_fallback +256,4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,194.9509,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,451.9496,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,646.9005,0,0.0,0.0,flydsl_fallback +256,8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,288.8715,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,857.3649,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,1146.2364,0,0.0,0.0,flydsl_fallback +256,16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,463.6276,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,1723.7202,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,2187.3478,0,0.0,0.0,flydsl_fallback +256,32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,702.7786,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0,3373.9364,moe_ck2stages_gemm2_64x128x128x128_1x1_MulABScaleExpertWeightShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight1_FP4X2_FP4X2_B16,0.0123,4076.715,0,0.0,0.0,flydsl_fallback diff --git a/aiter/configs/model_configs/kimik2_fp8fp4_tuned_fmoe.csv b/aiter/configs/model_configs/kimik2_fp8fp4_tuned_fmoe.csv new file mode 100755 index 0000000000..168c3f9745 --- /dev/null +++ b/aiter/configs/model_configs/kimik2_fp8fp4_tuned_fmoe.csv @@ -0,0 +1,65 @@ +cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1,block_m,ksplit,us1,kernelName1,err1,us2,kernelName2,err2,us,run_1stage,tflops,bw,_tag +256,1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,123.0893,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w2_gui_fp8,0.0%,105.2591,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2,0.0%,228.3484,0,394.99,9353.91, +256,2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,133.8259,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_w3_gui_fp8,0.0%,156.8906,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_bnt2_persist,0.0%,290.7165,0,620.5,7422.93, +256,4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,168.2474,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_gui_fp8,0.0%,292.1704,flydsl_moe2_afp8_wfp4_bf16_t128x256x256_reduce_persist,0.0%,460.4178,0,783.59,4782.63, +256,8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,302.4767,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_bnt0_gui_fp8,0.0%,514.8898,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,817.3665,0,882.78,2801.79, +256,16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,481.7208,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_bnt0_gui_fp8,0.0%,1006.5923,flydsl_moe2_afp8_wfp4_bf16_t128x256x256_reduce_bnt2,0.0%,1488.3131,0,969.63,1657.08, +256,32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,876.5187,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_bnt0_gui_xcd4_fp8,0.0%,1845.739,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_xcd4,0.0%,2722.2577,0,1060.23,1035.38, +256,1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,236.2331,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,151.5334,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2,0.0%,387.7665,0,465.2,10959.89, +256,2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,250.8169,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_gui_fp8,0.0%,208.9952,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_atomic_bnt2_persist,0.0%,459.8121,0,784.62,9290.53, +256,4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,321.0343,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_gui_fp8,0.0%,385.9324,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,706.9667,0,1020.63,6104.87, +256,8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,563.8091000000001,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_w3_bnt0_gui,0.0%,594.2764,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce,0.0%,1158.0855,0,1246.12,3802.84, +256,16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,870.6742,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui_fp8,0.0%,1154.1545,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,2024.8287,0,1425.41,2262.01, +256,32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1581.5182,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui_fp8,0.0%,2214.8461,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,3796.3643,0,1520.52,1299.27, +256,1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,241.4155,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,157.2231,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2_persist,0.0%,398.6386,0,509.08,10688.6, +256,2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,269.3204,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_w2_gui_fp8,0.0%,230.4879,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_bnt2,0.0%,499.8083,0,812.06,8569.1, +256,4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,342.8696,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w3_gui_fp8,0.0%,410.5732,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,753.4428,0,1077.39,5742.9, +256,8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,589.8479,flydsl_moe1_afp8_wfp4_bf16_t64x256x256_w2_bnt0_gui,0.0%,653.2881,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist,0.0%,1243.136,0,1305.97,3551.53, +256,16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,936.1311,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_bnt0_gui_fp8,0.0%,1291.6956,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,2227.8267,0,1457.47,2060.84, +256,32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,1700.6587,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_bnt0_gui_fp8,0.0%,2482.132,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,4182.7907,0,1552.55,1181.87, +256,1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,126.9676,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,117.2557,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2,0.0%,244.2233,0,415.47,8768.43, +256,2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,146.7519,flydsl_moe1_afp8_wfp4_bf16_t64x128x256_w3_gui_fp8,0.0%,166.7794,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce,0.0%,313.5313,0,647.26,6900.35, +256,4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,194.7301,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_gui_fp8,0.0%,311.4187,flydsl_moe2_afp8_wfp4_bf16_t128x256x256_reduce,0.0%,506.1488,0,801.89,4361.39, +256,8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,333.7686,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui_fp8,0.0%,545.4913,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,879.2599,0,923.22,2610.83, +256,16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,523.6283,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui_fp8,0.0%,1069.9622,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_bnt2_persist_sbm128,0.0%,1593.5905,0,1018.77,1551.06, +256,32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,907.3006,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w3_bnt0_gui_fp8,0.0%,2060.5967,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_xcd4_sbm128,0.0%,2967.8973,0,1094.04,951.54, +256,512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,119.1612,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,79.3586,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2_persist,0.0%,198.5198,0,227.17,10703.92, +256,131072,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,3052.6405,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w2_bnt0_gui_xcd4_fp8,0.0%,6954.0181,flydsl_moe2_afp8_wfp4_bf16_t128x256x256_reduce_xcd4,0.0%,10006.6586,0,1153.72,492.92, +256,512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,230.262,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,136.2811,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2_persist,0.0%,366.5431,0,246.07,11564.45, +256,131072,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,6018.3408,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w3_bnt0_gui_xcd4_fp8,0.0%,8085.1723,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_xcd4_sbm128,0.0%,14103.5131,0,1637.16,499.62, +256,512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,236.9899,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,138.7571,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2_persist,0.0%,375.747,0,270.05,11310.48, +256,131072,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,6384.608,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w3_bnt0_gui_fp8,0.0%,9168.6363,flydsl_moe2_afp8_wfp4_bf16_t64x256x256_reduce_xcd4_sbm128,0.0%,15553.2443,0,1670.13,453.76, +256,512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,125.4126,flydsl_moe1_afp8_wfp4_bf16_t32x128x256_w3_gui_fp8,0.0%,80.3704,flydsl_moe2_afp8_wfp4_bf16_t32x256x256_atomic_bnt2,0.0%,205.783,0,246.54,10352.87, +256,131072,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,3303.1397,flydsl_moe1_afp8_wfp4_bf16_t128x256x256_w3_bnt0_gui_fp8,0.0%,7645.4133,flydsl_moe2_afp8_wfp4_bf16_t128x256x256_reduce_xcd4_persist,0.0%,10948.553,0,1186.27,451.02, +256,512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,139.3696,cktile_a8w4_bm32,0.0,106.7266,cktile_a8w4_bm32,0.0,246.0962,0,0.0,0.0,flydsl_fallback +256,131072,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,3897.4695,cktile_a8w4_bm64,0.0,11790.5161,cktile_a8w4_bm64,0.0,15687.9856,0,0.0,0.0,flydsl_fallback +256,512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,265.1341,cktile_a8w4_bm32,0.0,159.6032,cktile_a8w4_bm32,0.0,424.7373,0,0.0,0.0,flydsl_fallback +256,131072,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,7644.1875,cktile_a8w4_bm64,0.0,12120.2159,cktile_a8w4_bm64,0.0,19764.4034,0,0.0,0.0,flydsl_fallback +256,512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,283.445,cktile_a8w4_bm32,0.0,160.9444,cktile_a8w4_bm32,0.0,444.3894,0,0.0,0.0,flydsl_fallback +256,131072,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,8892.2449,cktile_a8w4_bm64,0.0,14030.8646,cktile_a8w4_bm64,0.0,22923.1095,0,0.0,0.0,flydsl_fallback +256,512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,156.2523,cktile_a8w4_bm32,0.0,108.3953,cktile_a8w4_bm32,0.0,264.6476,0,0.0,0.0,flydsl_fallback +256,131072,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,4453.171,cktile_a8w4_bm64,0.0,13495.6183,cktile_a8w4_bm64,0.0,17948.7893,0,0.0,0.0,flydsl_fallback +256,1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,142.6997,cktile_a8w4_bm32,0.0,124.1896,cktile_a8w4_bm32,0.0,266.8893,0,0.0,0.0,flydsl_fallback +256,2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,173.4389,cktile_a8w4_bm32,0.0,201.3501,cktile_a8w4_bm32,0.0,374.789,0,0.0,0.0,flydsl_fallback +256,4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,224.7223,cktile_a8w4_bm32,0.0,380.1632,cktile_a8w4_bm32,0.0,604.8855,0,0.0,0.0,flydsl_fallback +256,8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,321.4143,cktile_a8w4_bm64,0.0,731.9281,cktile_a8w4_bm64,0.0,1053.3424,0,0.0,0.0,flydsl_fallback +256,16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,572.6357,cktile_a8w4_bm64,0.0,1445.6359,cktile_a8w4_bm64,0.0,2018.2716,0,0.0,0.0,flydsl_fallback +256,32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1018.5422,cktile_a8w4_bm64,0.0,2870.1453,cktile_a8w4_bm64,0.0,3888.6875,0,0.0,0.0,flydsl_fallback +256,1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,273.0062,cktile_a8w4_bm32,0.0,171.5034,cktile_a8w4_bm32,0.0,444.5096,0,0.0,0.0,flydsl_fallback +256,2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,318.9878,cktile_a8w4_bm32,0.0,253.7825,cktile_a8w4_bm32,0.0,572.7703,0,0.0,0.0,flydsl_fallback +256,4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,431.1773,cktile_a8w4_bm32,0.0,393.7221,cktile_a8w4_bm32,0.0,824.8994,0,0.0,0.0,flydsl_fallback +256,8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,688.4932,cktile_a8w4_bm32,0.0,765.334,cktile_a8w4_bm32,0.0,1453.8272,0,0.0,0.0,flydsl_fallback +256,16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,1189.1030999999998,cktile_a8w4_bm32,0.0,1463.1439,cktile_a8w4_bm32,0.0,2652.247,0,0.0,0.0,flydsl_fallback +256,32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1988.53,cktile_a8w4_bm64,0.0,3026.5083,cktile_a8w4_bm64,0.0,5015.0383,0,0.0,0.0,flydsl_fallback +256,1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,286.284,cktile_a8w4_bm32,0.0,175.2685,cktile_a8w4_bm32,0.0,461.5525,0,0.0,0.0,flydsl_fallback +256,2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,349.7174,cktile_a8w4_bm32,0.0,262.9544,cktile_a8w4_bm32,0.0,612.6718,0,0.0,0.0,flydsl_fallback +256,4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,441.6842,cktile_a8w4_bm32,0.0,423.2442,cktile_a8w4_bm32,0.0,864.9284,0,0.0,0.0,flydsl_fallback +256,8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,723.1021,cktile_a8w4_bm32,0.0,832.2823,cktile_a8w4_bm32,0.0,1555.3844,0,0.0,0.0,flydsl_fallback +256,16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1155.2034,cktile_a8w4_bm64,0.0,1754.2978,cktile_a8w4_bm64,0.0,2909.5012,0,0.0,0.0,flydsl_fallback +256,32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,2183.6604,cktile_a8w4_bm64,0.0,3390.8062,cktile_a8w4_bm64,0.0,5574.4666,0,0.0,0.0,flydsl_fallback +256,1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,161.18650000000002,cktile_a8w4_bm32,0.0,129.6653,cktile_a8w4_bm32,0.0,290.8518,0,0.0,0.0,flydsl_fallback +256,2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,193.0266,cktile_a8w4_bm32,0.0,215.5235,cktile_a8w4_bm32,0.0,408.5501,0,0.0,0.0,flydsl_fallback +256,4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,237.3953,cktile_a8w4_bm32,0.0,413.8183,cktile_a8w4_bm32,0.0,651.2136,0,0.0,0.0,flydsl_fallback +256,8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,344.532,cktile_a8w4_bm64,0.0,814.1671,cktile_a8w4_bm64,0.0,1158.6991,0,0.0,0.0,flydsl_fallback +256,16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,607.9427000000001,cktile_a8w4_bm64,0.0,1611.4606,cktile_a8w4_bm64,0.0,2219.4033,0,0.0,0.0,flydsl_fallback +256,32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,64,0,1150.4311,cktile_a8w4_bm64,0.0,3221.5736,cktile_a8w4_bm64,0.0,4372.0047,0,0.0,0.0,flydsl_fallback diff --git a/aiter/configs/model_configs/kimik2_fp8fp4_untuned_fmoe.csv b/aiter/configs/model_configs/kimik2_fp8fp4_untuned_fmoe.csv new file mode 100644 index 0000000000..f480eb797c --- /dev/null +++ b/aiter/configs/model_configs/kimik2_fp8fp4_untuned_fmoe.csv @@ -0,0 +1,33 @@ +token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1 +512,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +1024,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +2048,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +4096,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +8192,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +16384,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +32768,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +131072,7168,256,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +512,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +1024,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +2048,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +4096,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +8192,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +16384,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +32768,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +131072,7168,512,384,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +512,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +1024,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +2048,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +4096,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +8192,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +16384,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +32768,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +131072,7168,512,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +512,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +1024,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +2048,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +4096,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +8192,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +16384,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +32768,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 +131072,7168,256,385,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0 diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 60f1c32d40..86bc2915f2 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -263,7 +263,7 @@ def fused_moe_( and a1_scale is not None ): q_dtype_a = dtypes.fp8 - bf16_fp8_bound = 512 + bf16_fp8_bound = 256 if quant_type == QuantType.per_1x32: if activation == ActivationType.Swiglu: if get_gfx() != "gfx950" or M < bf16_fp8_bound: @@ -621,7 +621,7 @@ class MOEMetadata: run_1stage: bool = False has_bias: bool = False use_non_temporal_load: bool = True - fuse_fp4_quant: bool = False + fuse_quant: str = "" def _flydsl_stage1_wrapper( @@ -638,16 +638,16 @@ def _flydsl_stage1_wrapper( w1_scale=None, a1_scale=None, sorted_weights=None, - fuse_fp4_quant=False, - fuse_sort_scale=False, + out_scale=None, + out_scale_sorted=None, + bias1=None, **_kwargs, ): parsed = aiter.ops.flydsl.moe_kernels.get_flydsl_kernel_params(kernelName) if parsed is None: raise ValueError(f"Invalid FlyDSL kernel name: {kernelName}") act = "swiglu" if activation == ActivationType.Swiglu else "silu" - _fq = fuse_fp4_quant or parsed.get("fuse_fp4_quant", False) - _fss = fuse_sort_scale or (_fq and not fuse_sort_scale) + _a_scale_one = parsed.get("a_scale_one", False) return aiter.ops.flydsl.flydsl_moe_stage1( a=hidden_states, w1=w1, @@ -666,12 +666,14 @@ def _flydsl_stage1_wrapper( w1_scale=w1_scale, a1_scale=a1_scale, sorted_weights=sorted_weights, - fuse_fp4_quant=_fq, - fuse_sort_scale=_fss, + use_async_copy=True, k_batch=parsed.get("k_batch", 1), waves_per_eu=parsed.get("waves_per_eu", 3), b_nt=parsed.get("b_nt", 2), - gate_only=parsed.get("gate_only", False), + gate_mode=parsed.get("gate_mode", "separated"), + bias=bias1, + a_scale_one=_a_scale_one, + xcd_swizzle=parsed.get("xcd_swizzle", 0), ) @@ -688,6 +690,7 @@ def _flydsl_stage2_wrapper( w2_scale=None, a2_scale=None, sorted_weights=None, + bias2=None, **_kwargs, ): @@ -713,6 +716,10 @@ def _flydsl_stage2_wrapper( a2_scale=a2_scale, sorted_weights=sorted_weights, sort_block_m=parsed.get("sort_block_m", 0), + b_nt=parsed.get("b_nt", 0), + persist=parsed.get("persist", None), + bias=bias2, + xcd_swizzle=parsed.get("xcd_swizzle", 0), ) @@ -870,7 +877,7 @@ def FinalFunc(): ) in fused_moe_1stage_dict[get_gfx()]: if q_type == QuantType.per_1x128: # for fp8 blockscale, ck has better performance so disable assembly kernel - run_1stage = token > 32 and (inter_dim % 128 == 0) + run_1stage = token > 32 and (inter_dim % 256 == 0) elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8: run_1stage = token > 32 elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8: @@ -948,7 +955,7 @@ def get_block_m() -> int: is_flydsl1 = bool(kernelName1) and kernelName1.startswith("flydsl_") is_flydsl2 = bool(kernelName2) and kernelName2.startswith("flydsl_") if (is_flydsl1 or is_flydsl2) and is_flydsl_available(): - _s1_fq = is_flydsl1 and "_fq" in kernelName1 + _s1_fq = is_flydsl1 and "_fp4" in kernelName1.split("_t")[-1] if is_flydsl1: stage1_func = functools.partial( _flydsl_stage1_wrapper, @@ -980,13 +987,21 @@ def get_block_m() -> int: use_non_temporal_load=use_non_temporal_load, ) + _has_bias = ( + activation == ActivationType.Swiglu + and q_type == QuantType.per_1x32 + and dtype in [dtypes.bf16, dtypes.fp16] + ) + _s1_fp8q = is_flydsl1 and "_fp8" in kernelName1.split("_t")[-1] + _fuse_quant = "fp8" if _s1_fp8q else ("fp4" if _s1_fq else "") return MOEMetadata( stage1_func, stage2_func, block_m, int(ksplit), run_1stage, - fuse_fp4_quant=_s1_fq, + has_bias=_has_bias, + fuse_quant=_fuse_quant, ) if ( dtype in [dtypes.bf16, dtypes.fp16] @@ -1180,7 +1195,10 @@ def fused_moe_2stages( a1 = hidden_states.to(dtypes.fp8) M = sorted_ids.shape[0] N = a1.shape[-1] - a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device) + if metadata.fuse_quant == "fp8": + a1_scale = torch.empty([1], dtype=dtypes.fp8_e8m0, device=a1.device) + else: + a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device) elif quant_type == QuantType.per_1x32: if hidden_states.dtype == dtypes.fp4x2 and a1_scale is not None: @@ -1249,7 +1267,7 @@ def fused_moe_2stages( sorted_ids, sorted_expert_ids, num_valid_ids, - None if metadata.fuse_fp4_quant else a2, + None if metadata.fuse_quant else a2, topk, block_m=block_size_M, a1_scale=a1_scale, @@ -1259,7 +1277,7 @@ def fused_moe_2stages( sorted_weights=sorted_weights if doweight_stage1 else None, **extra_stage1_args, ) - if metadata.fuse_fp4_quant and isinstance(a2, tuple): + if metadata.fuse_quant == "fp4" and isinstance(a2, tuple): a2_raw, a2_scale = a2[0], a2[1] _fp4_bytes = token_num * topk * (inter_dim // 2) a2 = ( @@ -1268,6 +1286,9 @@ def fused_moe_2stages( .view(dtypes.fp4x2) .reshape(token_num, topk, -1) ) + elif metadata.fuse_quant == "fp8" and isinstance(a2, tuple): + a2, a2_scale = a2[0], a2[1] + a2 = a2.view(token_num, topk, -1) elif ( quant_type == QuantType.per_1x32 and dtype in [dtypes.bf16, dtypes.fp16] diff --git a/aiter/jit/utils/moe_recipes.py b/aiter/jit/utils/moe_recipes.py index 56ed009e5f..db555c8a70 100644 --- a/aiter/jit/utils/moe_recipes.py +++ b/aiter/jit/utils/moe_recipes.py @@ -143,6 +143,9 @@ def get_moe_ck2stages_prebuild_variants(aiter_csrc_dir: str) -> List[Dict]: mul_weight_stage = _get_mul_weight_stage(row) need_splitk = _should_include_splitk(row, quant_type) + if activation == "swiglu": + continue + for preshuffle in _infer_preshuffle_modes(b_dtype, quant_type): for splitk in [False, True] if need_splitk else [False]: key = ( diff --git a/aiter/ops/flydsl/kernels/mfma_epilogues.py b/aiter/ops/flydsl/kernels/mfma_epilogues.py index 7966e623ed..e4c4d0f559 100644 --- a/aiter/ops/flydsl/kernels/mfma_epilogues.py +++ b/aiter/ops/flydsl/kernels/mfma_epilogues.py @@ -31,6 +31,7 @@ from flydsl._mlir import ir import flydsl.expr as fx +from flydsl._mlir.dialects.arith import CmpIPredicate from flydsl.expr.typing import T @@ -111,6 +112,12 @@ def c_shuffle_epilog( write_row_to_lds: Callable, precompute_row: Callable | None = None, store_pair: Callable, + # When LDS overflows, split lds_out across two buffers by wave-group. + # Pass the second buffer here; first buffer is `lds_out`. + lds_out_split=None, + # Row offset in lds_out for 8-wave mode (MLIR index value). + # Shifts both write and read LDS indices by lds_row_offset * tile_n elements. + lds_row_offset=None, ): """LDS CShuffle epilogue skeleton. @@ -137,14 +144,173 @@ def c_shuffle_epilog( f"tile_n must be divisible by (CShuffleNLane*EVec) = {cshuffle_nlane*e_vec}, got tile_n={tile_n}" ) + # ===================== Split-LDS mode (early return) ===================== + # When lds_out_split is provided, waves are divided into two groups: + # Group A (waves 0..N/2-1) uses lds_out, columns [0, tile_n/2) + # Group B (waves N/2..N-1) uses lds_out_split, columns [tile_n/2, tile_n) + # Each group writes/reads independently; same barriers synchronise all waves. + if lds_out_split is not None: + if scf is None: + raise ValueError("scf module is required for split-LDS cshuffle") + + _half_n = int(tile_n) // 2 + _half_threads = int(block_size) // 2 + EVec = int(e_vec) + + CShuffleNLane_s = min(int(cshuffle_nlane), _half_n // EVec) + if _half_threads % CShuffleNLane_s != 0: + raise ValueError( + f"half_threads={_half_threads} not divisible by CShuffleNLane_split={CShuffleNLane_s}" + ) + CShuffleMLane_s = _half_threads // CShuffleNLane_s + if int(tile_m) % CShuffleMLane_s != 0: + raise ValueError( + f"tile_m={tile_m} not divisible by CShuffleMLane_split={CShuffleMLane_s}" + ) + m_reps_s = int(tile_m) // CShuffleMLane_s + n_reps_s = _half_n // (CShuffleNLane_s * EVec) + + _half_n_idx = arith.constant(_half_n, index=True) + _half_thr_idx = arith.constant(_half_threads, index=True) + _zero_idx = arith.constant(0, index=True) + + _is_group_b = arith.cmpi(CmpIPredicate.uge, tx, _half_thr_idx) + + # -- write phase (all waves, each to its group's LDS buffer) -- + n_tile_base_v = n_tile_base + col_base_local_a = n_tile_base_v + lane_mod_16 + col_base_local_b = col_base_local_a - _half_n_idx + + def _write_row_split(mi: int, ii: int, row_in_tile, row): + row_base_lds = row_in_tile * _half_n_idx + _if_g = scf.IfOp(_is_group_b) + with ir.InsertionPoint(_if_g.then_block): + write_row_to_lds( + mi=mi, + ii=ii, + row_in_tile=row_in_tile, + row=row, + row_base_lds=row_base_lds, + col_base_local=col_base_local_b, + num_acc_n=num_acc_n, + lds_out=lds_out_split, + ) + scf.YieldOp([]) + with ir.InsertionPoint(_if_g.else_block): + write_row_to_lds( + mi=mi, + ii=ii, + row_in_tile=row_in_tile, + row=row, + row_base_lds=row_base_lds, + col_base_local=col_base_local_a, + num_acc_n=num_acc_n, + lds_out=lds_out, + ) + scf.YieldOp([]) + + gpu.barrier() + default_epilog( + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_write_row_split, + ) + gpu.barrier() + + # -- read phase (each group reads from its own LDS buffer) -- + tx_local = tx - arith.select(_is_group_b, _half_thr_idx, _zero_idx) + c_nlane_s = arith.constant(CShuffleNLane_s, index=True) + m_lane_s = tx_local / c_nlane_s + n_lane_s = tx_local % c_nlane_s + c_evec = arith.constant(EVec, index=True) + + if frag_elem_type is None: + frag_elem_type = T.f16 + vec_frag = T.vec(EVec, frag_elem_type) + bx_m_v = bx_m + by_n_v = by_n + + _precomputed_rows_s = [] + for mr in range_constexpr(m_reps_s): + row_base_m = arith.constant(mr * CShuffleMLane_s, index=True) + row_local = row_base_m + m_lane_s + row = bx_m_v + row_local + row_ctx_raw = ( + precompute_row(row_local=row_local, row=row) + if precompute_row is not None + else None + ) + row_ctx = row_ctx_raw + row_pred = None + if ( + scf is not None + and row_ctx_raw is not None + and isinstance(row_ctx_raw, tuple) + and len(row_ctx_raw) == 2 + ): + row_ctx, row_pred = row_ctx_raw + _precomputed_rows_s.append((row_local, row, row_ctx, row_pred)) + + for mr in range_constexpr(m_reps_s): + row_local, row, row_ctx, row_pred = _precomputed_rows_s[mr] + + def _do_store_row_split(): + row_base_lds = row_local * _half_n_idx + for nr in range_constexpr(n_reps_s): + col_base_nr = arith.constant( + nr * (CShuffleNLane_s * EVec), index=True + ) + col_pair0_local = col_base_nr + (n_lane_s * c_evec) + lds_idx = row_base_lds + col_pair0_local + + _if_ld = scf.IfOp(_is_group_b, [vec_frag]) + with ir.InsertionPoint(_if_ld.then_block): + fb = vector.load_op(vec_frag, lds_out_split, [lds_idx]) + scf.YieldOp([fb]) + with ir.InsertionPoint(_if_ld.else_block): + fa = vector.load_op(vec_frag, lds_out, [lds_idx]) + scf.YieldOp([fa]) + frag = _if_ld.results[0] + + col_pair0 = col_pair0_local + arith.select( + _is_group_b, _half_n_idx, _zero_idx + ) + store_pair( + row_local=row_local, + row=row, + row_ctx=row_ctx, + col_pair0=col_pair0, + col_g0=by_n_v + col_pair0, + frag=frag, + ) + + if row_pred is not None: + _if_row = scf.IfOp(row_pred) + with _if_then(_if_row, scf): + _do_store_row_split() + else: + _do_store_row_split() + + return # split path complete + + # ===================== Standard (non-split) path below ===================== + # ---------------- Step 1: write C tile to LDS (row-major, fp16) ---------------- tile_n_idx = arith.constant(int(tile_n), index=True) n_tile_base_v = n_tile_base col_base_local = n_tile_base_v + lane_mod_16 # index within [0,tile_n) + _lds_row_base_offset = ( + lds_row_offset * tile_n_idx if lds_row_offset is not None else None + ) + def _write_row(mi: int, ii: int, row_in_tile, row): - # row_base_lds = row_in_tile * tile_n row_base_lds = row_in_tile * tile_n_idx + if _lds_row_base_offset is not None: + row_base_lds = row_base_lds + _lds_row_base_offset write_row_to_lds( mi=mi, ii=ii, @@ -189,6 +355,10 @@ def _write_row(mi: int, ii: int, row_in_tile, row): bx_m_v = bx_m by_n_v = by_n + # Batch-precompute all row contexts (sorted_idx loads) before the store loop. + # This issues all buffer_load instructions upfront so the compiler can pipeline + # them instead of serializing each load with s_waitcnt vmcnt(0). + _precomputed_rows = [] for mr in range_constexpr(m_reps_shuffle): row_base_m = arith.constant(mr * CShuffleMLane, index=True) row_local = row_base_m + m_lane @@ -212,8 +382,16 @@ def _write_row(mi: int, ii: int, row_in_tile, row): ): row_ctx, row_pred = row_ctx_raw + _precomputed_rows.append((row_local, row, row_ctx, row_pred)) + + # Now perform LDS reads and stores using the pre-fetched row contexts. + for mr in range_constexpr(m_reps_shuffle): + row_local, row, row_ctx, row_pred = _precomputed_rows[mr] + def _do_store_row(): row_base_lds = row_local * tile_n_idx + if _lds_row_base_offset is not None: + row_base_lds = row_base_lds + _lds_row_base_offset for nr in range_constexpr(n_reps_shuffle): col_base_nr = arith.constant(nr * (CShuffleNLane * EVec), index=True) col_pair0 = col_base_nr + (n_lane * c_evec) # even col within tile diff --git a/aiter/ops/flydsl/kernels/mfma_preshuffle_pipeline.py b/aiter/ops/flydsl/kernels/mfma_preshuffle_pipeline.py index ed8b77a33b..07f555d15e 100644 --- a/aiter/ops/flydsl/kernels/mfma_preshuffle_pipeline.py +++ b/aiter/ops/flydsl/kernels/mfma_preshuffle_pipeline.py @@ -606,6 +606,7 @@ def lds_load_pack_k32( __all__ = [ "PreshuffleBLayout", + "PreshuffleScaleLayout", "buffer_copy_gmem16_dwordx4", "lds_load_pack_k32", "lds_store_4b_xor16", diff --git a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py index fc88cf47f3..81d7bfeb4e 100644 --- a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py @@ -44,6 +44,25 @@ from .layout_utils import crd2idx, idx2crd, get as layout_get import functools +from enum import Enum + + +class GateMode(str, Enum): + """Gate/Up computation strategy for stage1 GEMM. + + SEPARATED: Two separate B-tile streams (gate + up), default mode. + MOCK_GATE_ONLY: Single B-tile stream over full [0, 2*inter_dim), simulates + gate-only by doubling grid X on top of SEPARATED layout. + Requires split-K (k_batch>1). NOT true gate-only. + GATE_ONLY: Reserved for future true gate-only implementation. + INTERLEAVE: Weight rows interleave gate/up (gate[0], up[0], gate[1], ...). + pack_N=2 routes even/odd N subtiles. NOT tied to split-K. + """ + + SEPARATED = "separated" + MOCK_GATE_ONLY = "mock_gate_only" + GATE_ONLY = "gate_only" + INTERLEAVE = "interleave" @contextmanager @@ -104,34 +123,29 @@ def compile_mixed_moe_gemm1( model_dim_pad: int = 0, inter_dim_pad: int = 0, persist_m: int = 1, - fuse_fp4_quant: bool = False, - fuse_sort_scale: bool = False, use_async_copy: bool = False, - waves_per_eu: int = 3, + waves_per_eu: int = 4, k_batch: int = 1, b_nt: int = 0, - gate_only: bool = False, + gate_mode: GateMode = GateMode.SEPARATED, + a_scale_one: bool = False, + xcd_swizzle: int = 0, ): - """Compile stage1 kernel (gate+up with silu) based on stage2 structure. + """Compile stage1 kernel (gate+up with silu/swiglu). - GEMM: silu(X @ W_gate.T) * (X @ W_up.T) -> [tokens*topk, inter_dim] + GEMM: act(X @ W_gate.T, X @ W_up.T) -> [tokens*topk, inter_dim] Direct store (no atomic). When k_batch>1 (split-K), each CTA computes a K-slice and atomically adds gate/up partials. Note: persist_m=1 (no persistence) is optimal for stage1 because K=model_dim is large, so each CTA is already compute-heavy. persist_m>1 serializes M blocks that the GPU can process in parallel. - When gate_only=True (requires k_batch>1), each workgroup computes - only one B-tile stream instead of interleaving gate and up. - The grid X dimension doubles (inter_in / tile_n instead of - inter_in / 2 / tile_n) so that by_n covers the full [0, 2*inter_dim) - range, naturally selecting gate or up rows by position. - This halves per-WG B-VMEM traffic and MFMA count, and the - doubled block count compensates. + gate_mode controls the gate/up computation strategy — see GateMode enum. """ gpu_arch = get_hip_arch() allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + _state = {} if a_dtype not in ("fp8", "fp16", "int8", "fp4"): raise ValueError( @@ -149,7 +163,7 @@ def compile_mixed_moe_gemm1( is_f4_b = b_dtype == "fp4" sort_block_m = max(32, tile_m) - num_waves = tile_n // 32 + num_waves = min(4, tile_n // 32) total_threads = num_waves * 64 pack_M = 1 if tile_m < 32 else 2 n_per_wave = tile_n // num_waves @@ -186,10 +200,20 @@ def _w_elem_type(): def out_elem(): return T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) + mock_gate_only = gate_mode is GateMode.MOCK_GATE_ONLY + gate_up_interleave = gate_mode is GateMode.INTERLEAVE + + # Padding semantics: model_dim and inter_dim INCLUDE padding. + # model_dim = model_dim_true + model_dim_pad (K direction) + # inter_dim = inter_dim_true + inter_dim_pad (N direction) + # Tensor sizes use the padded dimensions (inter_dim, model_dim). + # Padding only affects kernel internal logic and grid computation. + _inter_dim_valid = inter_dim - inter_dim_pad + # Split-K validation _is_splitk = k_batch > 1 - if gate_only and not _is_splitk: - raise ValueError("gate_only requires k_batch > 1 (split-K)") + if mock_gate_only and not _is_splitk: + raise ValueError("mock_gate_only requires k_batch > 1 (split-K)") if _is_splitk: _k_per_batch = model_dim // k_batch assert ( @@ -199,15 +223,11 @@ def out_elem(): _k_per_batch % tile_k == 0 ), f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" - fuse_fp4_quant = False + out_dtype = "bf16" else: _k_per_batch = model_dim _k_dim = _k_per_batch - # Stage1 gate-only: output = [tokens*topk, inter_dim], direct store (accumulate=False) - # Weight layout: [E * 2*inter_dim, model_dim] pre-shuffled; gate = first inter_dim rows per expert - # GEMM: X[tokens, model_dim] @ W_gate[inter_dim, model_dim].T -> [tokens*topk, inter_dim] - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) if bytes_x_per_tile % total_threads != 0: raise ValueError( @@ -236,45 +256,74 @@ def out_elem(): else: _use_cshuffle_epilog = bool(use_cshuffle_epilog) - _need_quant = fuse_fp4_quant - _need_sort = _need_quant and fuse_sort_scale + _need_fp4 = out_dtype == "fp4" + _need_fp8 = out_dtype == "fp8" + _need_quant = _need_fp4 or _need_fp8 + _need_sort = _need_quant if _need_quant: _use_cshuffle_epilog = True - _fp4q_tag = "_fp4q" if _need_quant else "" + _fp4q_tag = "_fp4q" if _need_fp4 else "" + _fp8q_tag = "_fp8q" if _need_fp8 else "" _sort_tag = "_sort" if _need_sort else "" _async_tag = "_async" if use_async_copy else "" _sk_tag = f"_sk{k_batch}" if _is_splitk else "" - _go_tag = "_go" if gate_only else "" + _go_tag = "_go" if mock_gate_only else "" + _gui_tag = "_gui" if gate_up_interleave else "" + _as1_tag = "_as1" if a_scale_one else "" + _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" module_name = ( f"mfma_moe1_silu_mul_a{a_dtype}_w{b_dtype}_{out_s}" - f"_t{tile_m}x{tile_n}x{tile_k}_pm{persist_m}{_fp4q_tag}{_sort_tag}{_async_tag}{_sk_tag}{_go_tag}_v32" + f"_t{tile_m}x{tile_n}x{tile_k}_pm{persist_m}{_fp4q_tag}{_fp8q_tag}{_sort_tag}{_async_tag}{_sk_tag}{_go_tag}{_gui_tag}{_as1_tag}{_xcd_tag}_v32" ).replace("-", "_") - # -- LDS sizing (split ping/pong allocators) -- + # -- LDS sizing -- _cshuffle_elem_bytes = 4 if _need_quant else (4 if out_is_f32 else 2) _single_x_bytes = int(tile_m) * int(lds_stride) * int(a_elem_bytes) lds_out_bytes = ( _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 ) lds_tid_bytes = int(tile_m) * 4 - _buffer_bytes = max(_single_x_bytes, lds_out_bytes) - _buffer_elems = _buffer_bytes if a_elem_bytes == 1 else (_buffer_bytes // 2) + _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) + + # Determine whether we need wave-group split for lds_out. + # Standard layout: pong = max(input, lds_out) + tid, ping = input. + # When this overflows, split lds_out into two halves across pong & ping. + _GLOBAL_ALIGN = 1024 + _std_pong = max(_single_x_bytes, lds_out_bytes) + lds_tid_bytes + _std_ping = _single_x_bytes + _std_pong_aligned = allocator_pong._align(_std_pong, 128) + _std_total = allocator_pong._align( + _std_pong_aligned, _GLOBAL_ALIGN + ) + allocator_pong._align(_std_ping, 128) + _lds_limit = {"gfx950": 163840, "gfx942": 65536}.get(gpu_arch, 0) + + _split_lds_out = ( + _lds_limit > 0 + and lds_out_bytes > 0 + and _std_total > _lds_limit + and num_waves >= 2 + ) + + if _split_lds_out: + _half_out_bytes = _cshuffle_elem_bytes * int(tile_m) * (int(tile_n) // 2) + _pong_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + _ping_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + else: + _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) + _ping_buffer_bytes = _single_x_bytes def x_lds_elem(): return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) - allocator_pong.ptr = lds_pong_offset + _buffer_bytes + allocator_pong.ptr = lds_pong_offset + _pong_buffer_bytes _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) - allocator_ping.ptr = lds_ping_offset + _buffer_bytes - - # if tile_m == 16: - # waves_per_eu = 1 + allocator_ping.ptr = lds_ping_offset + _ping_buffer_bytes if waves_per_eu is not None and waves_per_eu >= 1: _total_cu_lds = 160 * 1024 @@ -328,19 +377,21 @@ def x_lds_elem(): for ku in range(_pipe_k_unroll): for ni in range(_pipe_num_acc_n): _pipe_b_loads.append(("gate", ku, ni)) - if not gate_only: + if not mock_gate_only and not gate_up_interleave: _pipe_b_loads.append(("up", ku, ni)) # MFMA order: B-major (fix B, cycle all A tiles before next B) # Each entry: one (k, ni) pair; the compute function loops over all mi. # This keeps B operands (from VMEM) fixed while cycling A (from LDS, no wait). + _pipe_num_acc_n_packed = _pipe_num_acc_n // pack_N _pipe_all_mfma = [] for _ku128 in range(_pipe_k_unroll_packed): - for _ikxdl in range(pack_K): - for _inxdl in range(pack_N): - _k_idx = _ku128 * pack_K + _ikxdl - _ni_idx = _inxdl - _pipe_all_mfma.append((_k_idx, _ni_idx, _ikxdl, _inxdl, _ku128)) + for _ni_packed in range(_pipe_num_acc_n_packed): + for _ikxdl in range(pack_K): + for _inxdl in range(pack_N): + _k_idx = _ku128 * pack_K + _ikxdl + _ni_idx = _ni_packed * pack_N + _inxdl + _pipe_all_mfma.append((_k_idx, _ni_idx, _ikxdl, _inxdl, _ku128)) # Group MFMAs per scheduling phase (wider M -> more MFMAs per phase) _pipe_mfma_per_phase = max(1, len(_pipe_all_mfma) // 4) @@ -383,6 +434,10 @@ def x_lds_elem(): _pp_b_loads = [p["b_loads"] for p in _pipe_phases] _pp_has_scale = [p["has_scale"] for p in _pipe_phases] + fp4_ratio = 2 if a_dtype == "fp4" else 1 + gui_ratio = 1 if gate_up_interleave else 2 + _vmcnt_before_barrier = tile_m // 32 // fp4_ratio + tile_n // 32 * gui_ratio + if True: @flyc.kernel @@ -407,7 +462,9 @@ def moe_gemm1( tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) - size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + size_expert_ids_in = arith.index_cast( + ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + ) x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) f32 = T.f32 @@ -423,6 +480,7 @@ def moe_gemm1( # --- Stage1 dimension mapping --- # X: [tokens, model_dim] -- M = sorted tokens, K = model_dim # W: [E*2*inter_dim, model_dim] gate portion -- N = inter_dim + # Out: [tokens*topk, inter_dim] # B preshuffle layout: [E*2*inter_dim, model_dim] # Gate rows for expert e: [e*2*inter_dim, e*2*inter_dim + inter_dim) @@ -448,14 +506,11 @@ def moe_gemm1( arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) ) - _eff_lds_stride = 0 - _eff_tile_k_bytes = 0 + _eff_lds_stride = lds_stride + _eff_tile_k_bytes = tile_k_bytes if const_expr(use_async_copy and a_elem_vec_pack > 1): _eff_lds_stride = lds_stride // a_elem_vec_pack _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack - else: - _eff_lds_stride = lds_stride - _eff_tile_k_bytes = tile_k_bytes shape_lds = fx.make_shape(tile_m, _eff_lds_stride) stride_lds = fx.make_stride(_eff_lds_stride, 1) @@ -464,6 +519,44 @@ def moe_gemm1( tx = gpu.thread_id("x") by = gpu.block_id("x") # tile along inter_dim (N) bx_persist = gpu.block_id("y") # persistent WG index + + if xcd_swizzle > 0: + _NUM_XCDS_S1 = 8 + _c1_sw = arith.constant(1, index=True) + _c_tn_sw = arith.constant(tile_n, index=True) + _c_idp_sw = arith.constant(2 * inter_dim_pad, index=True) + if const_expr(mock_gate_only or gate_up_interleave): + _gx = (n_in - _c_idp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw + else: + _c2_sw = arith.constant(2, index=True) + _gx = ( + (n_in - _c_idp_sw + _c2_sw * _c_tn_sw - _c1_sw) + / _c_tn_sw + / _c2_sw + ) + _c_pm_sw = arith.constant(persist_m, index=True) + _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw + + _linear_id = bx_persist * _gx + by + _num_wgs = _gx * _gy + + _c_xcds = arith.constant(_NUM_XCDS_S1, index=True) + _wgs_per_xcd = _num_wgs / _c_xcds + _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) + + _WGM_S1 = xcd_swizzle + _c_wgm = arith.constant(_WGM_S1, index=True) + _num_wgid_in_group = _c_wgm * _gx + _group_id = _wgid / _num_wgid_in_group + _first_pid_m = _group_id * _c_wgm + _remaining_m = _gy - _first_pid_m + _cmp_m = arith.cmpi(CmpIPredicate.ult, _remaining_m, _c_wgm) + _group_size_m = arith.select(_cmp_m, _remaining_m, _c_wgm) + + _wgid_in_group = _wgid % _num_wgid_in_group + bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) + by = _wgid_in_group / _group_size_m + by_n = by * arith.constant(tile_n, index=True) k_base_idx = arith.index(0) @@ -478,24 +571,40 @@ def moe_gemm1( base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() lds_x_pong = SmemPtr( - base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_buffer_elems,) + base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_input_elems,) ).get() lds_x_ping = SmemPtr( - base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_buffer_elems,) + base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_input_elems,) ).get() _lds_out_elem_type = ( T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16) ) - lds_out = ( - SmemPtr( + if _split_lds_out and _use_cshuffle_epilog: + _half_out_elems = int(tile_m) * (int(tile_n) // 2) + lds_out = SmemPtr( base_ptr_pong, lds_pong_offset, _lds_out_elem_type, - shape=(tile_m * tile_n,), + shape=(_half_out_elems,), ).get() - if _use_cshuffle_epilog - else None - ) + lds_out_B = SmemPtr( + base_ptr_ping, + lds_ping_offset, + _lds_out_elem_type, + shape=(_half_out_elems,), + ).get() + else: + lds_out = ( + SmemPtr( + base_ptr_pong, + lds_pong_offset, + _lds_out_elem_type, + shape=(tile_m * tile_n,), + ).get() + if _use_cshuffle_epilog + else None + ) + lds_out_B = None lds_tid = SmemPtr( base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) ).get() @@ -514,17 +623,6 @@ def moe_gemm1( w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) # Out: [tokens*topk, inter_dim] - out_nbytes_idx = ( - tokens_in - * arith.index(topk) - * n_in - * arith.constant(out_elem_bytes, index=True) - ) - out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes_idx) - buffer_ops.create_buffer_resource( - arg_out, max_size=False, num_records_bytes=out_nbytes_i32 - ) - numids_rsrc = buffer_ops.create_buffer_resource( arg_num_valid_ids, max_size=False, @@ -536,7 +634,7 @@ def moe_gemm1( sx_rsrc = 1 sw_rsrc = 1 - if const_expr(not is_f16_a): + if const_expr(not (is_f16_a or a_scale_one)): # A scale: [sorted_size, model_dim/32] pre-scattered by caller c32 = arith.constant(32, index=True) kblk = k_in / c32 @@ -574,16 +672,20 @@ def moe_gemm1( expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 ) + bias_rsrc = ( + buffer_ops.create_buffer_resource(arg_bias, max_size=False) + if enable_bias + else None + ) # Sorted-scale buffer resource for fused mxfp4 quantization _sorted_scale_cols = inter_dim // 32 _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) + sorted_scale_rsrc = None if const_expr(_need_sort): sorted_scale_rsrc = buffer_ops.create_buffer_resource( arg_out_scale_sorted, max_size=False ) - else: - sorted_scale_rsrc = None # ---- persist_m loop (same pattern as stage2) ---- _PERSIST_M = persist_m @@ -714,6 +816,7 @@ def load_x_tile(base_k): gate_n_blk_list = [] up_n_intra_list = [] up_n_blk_list = [] + col_g_list = [] c_n0_static = experts * (2 * inter_dim) // 16 layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) inter_idx = arith.constant(inter_dim, index=True) @@ -721,28 +824,52 @@ def load_x_tile(base_k): for i in range_constexpr(num_acc_n): offset = i * 16 c_offset = arith.constant(offset, index=True) + if const_expr(not gate_up_interleave): + col_g = by_n + n_tile_base + c_offset + lane_mod_16 + col_g_list.append(col_g) global_n = by_n + n_tile_base + c_offset + lane_mod_16 - # Gate: rows [expert_off, expert_off + inter_dim) - # For gate_only, by_n covers [0, 2*inter_dim) so this - # indexes into both gate and up regions naturally. + # Gate/interleave: rows [expert_off, expert_off + 2*inter_dim) gate_row_w = expert_off_idx + global_n gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) gate_n_blk_list.append(layout_get(gate_coord, 0)) gate_n_intra_list.append(layout_get(gate_coord, 1)) - if const_expr(not gate_only): - # Up: rows [expert_off + inter_dim, expert_off + 2*inter_dim) + if const_expr(not mock_gate_only and not gate_up_interleave): up_row_w = gate_row_w + inter_idx up_coord = idx2crd(up_row_w, layout_n_blk_intra) up_n_blk_list.append(layout_get(up_coord, 0)) up_n_intra_list.append(layout_get(up_coord, 1)) + if const_expr(gate_up_interleave): + _gui_num_acc_n_out = num_acc_n // pack_N + for _gui_i in range_constexpr(_gui_num_acc_n_out): + _gui_offset = _gui_i * 16 + _gui_c_offset = arith.constant(_gui_offset, index=True) + _gui_col_g = ( + (by_n + n_tile_base) // arith.constant(2, index=True) + + _gui_c_offset + + lane_mod_16 + ) + col_g_list.append(_gui_col_g) + m_repeat = tile_m // 16 k_unroll = tile_k_bytes // 128 k_unroll_packed = k_unroll // pack_K m_repeat_packed = m_repeat // pack_M num_acc_n_packed = num_acc_n // pack_N + _K_per_ku = tile_k // k_unroll + _pad_k_elems = ( + (model_dim_pad % tile_k) + if (not _is_splitk and model_dim_pad > 0) + else 0 + ) + _pad_ku_skip = _pad_k_elems // _K_per_ku + _tail_ku = k_unroll - _pad_ku_skip + _tail_ku_packed = ( + (_tail_ku + pack_K - 1) // pack_K if _pad_ku_skip > 0 else None + ) + # B load for gate and up separately def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): c64 = arith.constant(64, index=True) @@ -774,12 +901,14 @@ def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): ) return b0, b1 - def load_b_tile(base_k): + def load_b_tile(base_k, ku_limit=k_unroll): """Load B tiles. Returns (gate_b_tile, up_b_tile). - When gate_only, up_b_tile is None.""" + When mock_gate_only or gate_up_interleave, up_b_tile is None.""" gate_b_tile = [] - up_b_tile = [] if not gate_only else None - for ku in range_constexpr(k_unroll): + up_b_tile = ( + [] if (not mock_gate_only and not gate_up_interleave) else None + ) + for ku in range_constexpr(ku_limit): g_packs0, g_packs1 = [], [] u_packs0, u_packs1 = [], [] for ni in range_constexpr(num_acc_n): @@ -788,14 +917,16 @@ def load_b_tile(base_k): ) g_packs0.append(gb0) g_packs1.append(gb1) - if const_expr(not gate_only): + if const_expr( + not mock_gate_only and not gate_up_interleave + ): ub0, ub1 = load_b_packs_k64( base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni] ) u_packs0.append(ub0) u_packs1.append(ub1) gate_b_tile.append((g_packs0, g_packs1)) - if const_expr(not gate_only): + if const_expr(not mock_gate_only and not gate_up_interleave): up_b_tile.append((u_packs0, u_packs1)) return gate_b_tile, up_b_tile @@ -821,7 +952,7 @@ def load_b_tile(base_k): _gate_scale_bases.append( _gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem ) - if const_expr(not gate_only): + if const_expr(not mock_gate_only and not gate_up_interleave): _up_mni = ( expert_off_idx + inter_idx + _col_base ) // arith.constant(32, index=True) @@ -829,12 +960,13 @@ def load_b_tile(base_k): _up_mni * layout_b_scale.stride_n0 + _scale_lane_elem ) - _a_scale_bases = [] - for _mi in range_constexpr(m_repeat_packed): - _a_mni = _mi + bx_m // scale_mn_pack // 16 - _a_scale_bases.append( - _a_mni * layout_a_scale.stride_n0 + _scale_lane_elem - ) + if const_expr(not a_scale_one): + _a_scale_bases = [] + for _mi in range_constexpr(m_repeat_packed): + _a_mni = _mi + bx_m // scale_mn_pack // 16 + _a_scale_bases.append( + _a_mni * layout_a_scale.stride_n0 + _scale_lane_elem + ) _c16_idx = arith.constant(16, index=True) _c2_idx = arith.constant(2, index=True) @@ -888,24 +1020,33 @@ def _rearrange_b_scale(raw_i32): b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) ) - def prefetch_ab_scale_tile(base_k): + if const_expr(a_scale_one): + _as1_const = arith.constant(0x7F7F7F7F, type=T.i32) + _as1_vec = vector.from_elements(T.vec(1, T.i32), [_as1_const]) + + def prefetch_ab_scale_tile(base_k, ku_packed_limit=k_unroll_packed): a_scale_tile = [] gate_b_scale = [] - up_b_scale = [] if not gate_only else None - for ku in range_constexpr(k_unroll_packed): + up_b_scale = ( + [] if (not mock_gate_only and not gate_up_interleave) else None + ) + for ku in range_constexpr(ku_packed_limit): k_off = (ku + base_k) * layout_b_scale.stride_k0 for mi in range_constexpr(m_repeat_packed): - s = buffer_ops.buffer_load( - sx_rsrc, - _a_scale_bases[mi] + k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - s = _rearrange_a_scale(s) - a_scale_tile.append( - vector.from_elements(T.vec(1, T.i32), [s]) - ) + if const_expr(a_scale_one): + a_scale_tile.append(_as1_vec) + else: + s = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[mi] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + s = _rearrange_a_scale(s) + a_scale_tile.append( + vector.from_elements(T.vec(1, T.i32), [s]) + ) for ni in range_constexpr(num_acc_n_packed): gs = buffer_ops.buffer_load( sw_rsrc, @@ -918,7 +1059,9 @@ def prefetch_ab_scale_tile(base_k): gate_b_scale.append( vector.from_elements(T.vec(1, T.i32), [gs]) ) - if const_expr(not gate_only): + if const_expr( + not mock_gate_only and not gate_up_interleave + ): us = buffer_ops.buffer_load( sw_rsrc, _up_scale_bases[ni] + k_off, @@ -1032,10 +1175,10 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): ) return a0, a1 - def prefetch_full_a_from_lds(lds_buffer): + def prefetch_full_a_from_lds(lds_buffer, ku_limit=k_unroll): """Load entire A tile from LDS into registers before compute.""" a_regs = [] - for k_idx in range_constexpr(k_unroll): + for k_idx in range_constexpr(ku_limit): col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack for mi_idx in range_constexpr(m_repeat): mi_val = arith.constant(mi_idx * 16, index=True) @@ -1063,31 +1206,69 @@ def compute_tile( up_b_scale=None, *, prefetch_epilogue=False, + ku_count=k_unroll, ): gate_list = list(acc_gate_in) - up_list = list(acc_up_in) if not gate_only else None + _single_b = mock_gate_only or gate_up_interleave + up_list = None if _single_b else list(acc_up_in) mfma_res_ty = vec4_f32 epilogue_pf = None - if const_expr(prefetch_epilogue and doweight_stage1): - tw_pf = [] - lane_div_16_mul4_pf = lane_div_16 * arith.index(4) - ii_idx_list_pf = [ - arith.constant(ii, index=True) for ii in range(4) - ] - for mi in range_constexpr(m_repeat): - mi_base_pf = arith.constant(mi * 16, index=True) - for ii in range_constexpr(4): - row_off_pf = lane_div_16_mul4_pf + ii_idx_list_pf[ii] - sorted_row_pf = bx_m + mi_base_pf + row_off_pf - tw_pf.append( + bias_pf = None + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): + bias_pf = [] + for ni in range_constexpr(num_acc_n): + if const_expr(gate_up_interleave): + _logical_col = ( + (by_n + n_tile_base) + // arith.constant(2, index=True) + + arith.constant((ni // 2) * 16, index=True) + + lane_mod_16 + ) + _up_off = ( + inter_idx + if (ni % 2 == 1) + else arith.constant(0, index=True) + ) + bias_offset = ( + expert_off_idx + _up_off + _logical_col + ) + else: + global_n = ( + by_n + + n_tile_base + + arith.constant(ni * 16, index=True) + + lane_mod_16 + ) + bias_offset = expert_off_idx + global_n + bias_pf.append( buffer_ops.buffer_load( - sorted_w_rsrc, - sorted_row_pf, - vec_width=1, - dtype=f32, + bias_rsrc, bias_offset, vec_width=1, dtype=f32 ) ) - epilogue_pf = (None, tw_pf, None) + tw_pf = None + if const_expr(doweight_stage1): + tw_pf = [] + lane_div_16_mul4_pf = lane_div_16 * arith.index(4) + ii_idx_list_pf = [ + arith.constant(ii, index=True) for ii in range(4) + ] + for mi in range_constexpr(m_repeat): + mi_base_pf = arith.constant(mi * 16, index=True) + for ii in range_constexpr(4): + row_off_pf = ( + lane_div_16_mul4_pf + ii_idx_list_pf[ii] + ) + sorted_row_pf = bx_m + mi_base_pf + row_off_pf + tw_pf.append( + buffer_ops.buffer_load( + sorted_w_rsrc, + sorted_row_pf, + vec_width=1, + dtype=f32, + ) + ) + epilogue_pf = (None, tw_pf, bias_pf) c0_i64 = arith.constant(0, type=T.i64) vec4_i64 = T.vec(4, T.i64) @@ -1097,9 +1278,10 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) + _eff_packed = (ku_count + pack_K - 1) // pack_K # B-major: fix B (ni), cycle A (mi) -- B from VMEM stays # in registers while A from LDS is repacked per mi. - for ku128 in range_constexpr(k_unroll_packed): + for ku128 in range_constexpr(_eff_packed): for ni in range_constexpr(num_acc_n_packed): gate_bs_i32 = gate_b_scale[ku128 * num_acc_n_packed + ni] gate_bs_val = vector.extract( @@ -1107,85 +1289,88 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): static_position=[0], dynamic_position=[], ) - if const_expr(not gate_only): + if const_expr(not _single_b): up_bs_i32 = up_b_scale[ku128 * num_acc_n_packed + ni] up_bs_val = vector.extract( up_bs_i32, static_position=[0], dynamic_position=[] ) for ikxdl in range_constexpr(pack_K): k_idx = ku128 * pack_K + ikxdl - gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] - if const_expr(not gate_only): - up_bp0, up_bp1 = up_b_tile_in[k_idx] - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl - gb0 = gate_bp0[ni_idx] - gb1 = gate_bp1[ni_idx] - gb128 = pack_i64x4_to_i32x8( - gb0, gb1, c0_i64, c0_i64 - ) - if const_expr(not gate_only): - ub0 = up_bp0[ni_idx] - ub1 = up_bp1[ni_idx] - ub128 = pack_i64x4_to_i32x8( - ub0, ub1, c0_i64, c0_i64 - ) - for mi in range_constexpr(m_repeat_packed): - a_scale_i32 = a_scale[ - ku128 * m_repeat_packed + mi - ] - a_scale_val = vector.extract( - a_scale_i32, - static_position=[0], - dynamic_position=[], + if k_idx < ku_count: + gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] + if const_expr(not _single_b): + up_bp0, up_bp1 = up_b_tile_in[k_idx] + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + gb0 = gate_bp0[ni_idx] + gb1 = gate_bp1[ni_idx] + gb128 = pack_i64x4_to_i32x8( + gb0, gb1, c0_i64, c0_i64 ) - for imxdl in range_constexpr(pack_M): - mi_idx = mi * pack_M + imxdl - _a_reg_idx = k_idx * m_repeat + mi_idx - if const_expr(is_f8_a): - a0, a1, a2, a3 = a_tile_regs[_a_reg_idx] - a128 = pack_i64x4_to_i32x8( - a0, a1, a2, a3 - ) - else: - a0, a1 = a_tile_regs[_a_reg_idx] - a128 = pack_i64x4_to_i32x8( - a0, a1, c0_i64, c0_i64 - ) - acc_idx = mi_idx * num_acc_n + ni_idx - gate_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - gb128, - gate_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - gate_bs_val, - ], - ) + if const_expr(not _single_b): + ub0 = up_bp0[ni_idx] + ub1 = up_bp1[ni_idx] + ub128 = pack_i64x4_to_i32x8( + ub0, ub1, c0_i64, c0_i64 ) - if const_expr(not gate_only): - up_list[acc_idx] = ( + for mi in range_constexpr(m_repeat_packed): + a_scale_i32 = a_scale[ + ku128 * m_repeat_packed + mi + ] + a_scale_val = vector.extract( + a_scale_i32, + static_position=[0], + dynamic_position=[], + ) + for imxdl in range_constexpr(pack_M): + mi_idx = mi * pack_M + imxdl + _a_reg_idx = k_idx * m_repeat + mi_idx + if const_expr(is_f8_a): + a0, a1, a2, a3 = a_tile_regs[ + _a_reg_idx + ] + a128 = pack_i64x4_to_i32x8( + a0, a1, a2, a3 + ) + else: + a0, a1 = a_tile_regs[_a_reg_idx] + a128 = pack_i64x4_to_i32x8( + a0, a1, c0_i64, c0_i64 + ) + acc_idx = mi_idx * num_acc_n + ni_idx + gate_list[acc_idx] = ( rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, [ a128, - ub128, - up_list[acc_idx], + gb128, + gate_list[acc_idx], cbsz, blgp, ikxdl * pack_M + imxdl, a_scale_val, ikxdl * pack_N + inxdl, - up_bs_val, + gate_bs_val, ], ) ) + if const_expr(not _single_b): + up_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) + ) return gate_list, up_list, epilogue_pf def load_a_subtile(k_idx, mi_idx, lds_buffer): @@ -1200,6 +1385,8 @@ def load_a_subtile(k_idx, mi_idx, lds_buffer): else: return (a0, a1) + _single_b_pipe = mock_gate_only or gate_up_interleave + def compute_bmajor_mfma_phase( all_a_tiles, gate_b_single, @@ -1221,7 +1408,7 @@ def compute_bmajor_mfma_phase( all_a_tiles: flat list indexed by [k*m_repeat + mi]. gate_b_single/up_b_single: (b0, b1) for one specific ni. - When gate_only, up_b_single is None. + When _single_b_pipe (mock_gate_only or interleave), up_b_single is None. a_scale_vals: list of A scale scalars indexed by mi_packed. """ c0_i64 = arith.constant(0, type=T.i64) @@ -1234,7 +1421,7 @@ def _pack(x0, x1, x2, x3): mfma_res_ty = vec4_f32 gb128 = _pack(gate_b_single[0], gate_b_single[1], c0_i64, c0_i64) - if const_expr(not gate_only): + if const_expr(not _single_b_pipe): ub128 = _pack(up_b_single[0], up_b_single[1], c0_i64, c0_i64) for mi_p in range_constexpr(m_repeat_packed): @@ -1263,7 +1450,7 @@ def _pack(x0, x1, x2, x3): gate_bs_val, ], ) - if const_expr(not gate_only): + if const_expr(not _single_b_pipe): up_list[acc_idx] = ( rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, @@ -1312,7 +1499,7 @@ def _interleaved_half( _k_off = _sk * layout_b_scale.stride_k0 rocdl.sched_barrier(0) - rocdl.s_waitcnt(3) + rocdl.s_waitcnt(_vmcnt_before_barrier) _barrier() rocdl.sched_barrier(0) @@ -1333,17 +1520,25 @@ def _interleaved_half( dynamic_position=[], ) ) - _prev_gsv = vector.extract( - prev_gate_bs[0], - static_position=[0], - dynamic_position=[], - ) - if const_expr(not gate_only): - _prev_usv = vector.extract( - prev_up_bs[0], - static_position=[0], - dynamic_position=[], + _prev_gsv_list = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + _prev_gsv_list.append( + vector.extract( + prev_gate_bs[_gs_ni], + static_position=[0], + dynamic_position=[], + ) ) + if const_expr(not _single_b_pipe): + _prev_usv_list = [] + for _us_ni in range_constexpr(num_acc_n_packed): + _prev_usv_list.append( + vector.extract( + prev_up_bs[_us_ni], + static_position=[0], + dynamic_position=[], + ) + ) # ---- Execute phases from unified schedule ---- _a_all = {} @@ -1355,31 +1550,38 @@ def _interleaved_half( if const_expr(_pp_has_scale[_p]): _new_as_list = [] for _mi_p in range_constexpr(m_repeat_packed): - _raw_as = buffer_ops.buffer_load( - sx_rsrc, - _a_scale_bases[_mi_p] + _k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - _new_as_list.append(_rearrange_a_scale(_raw_as)) - _new_gs = buffer_ops.buffer_load( - sw_rsrc, - _gate_scale_bases[0] + _k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - _new_gs = _rearrange_b_scale(_new_gs) - if const_expr(not gate_only): - _new_us = buffer_ops.buffer_load( + if const_expr(a_scale_one): + _new_as_list.append(_as1_const) + else: + _raw_as = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[_mi_p] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_as_list.append(_rearrange_a_scale(_raw_as)) + _new_gs_list = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + _gs_raw = buffer_ops.buffer_load( sw_rsrc, - _up_scale_bases[0] + _k_off, + _gate_scale_bases[_gs_ni] + _k_off, vec_width=1, dtype=T.i32, cache_modifier=0, ) - _new_us = _rearrange_b_scale(_new_us) + _new_gs_list.append(_rearrange_b_scale(_gs_raw)) + if const_expr(not _single_b_pipe): + _new_us_list = [] + for _us_ni in range_constexpr(num_acc_n_packed): + _us_raw = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[_us_ni] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_us_list.append(_rearrange_b_scale(_us_raw)) # B VMEM loads for _b_j in range_constexpr(len(_pp_b_loads[_p])): @@ -1414,12 +1616,13 @@ def _interleaved_half( rocdl.s_setprio(1) for _m_j in range_constexpr(len(_pp_mfma[_p])): _k_idx, _ni_idx, _ikxdl, _inxdl, _ku128 = _pp_mfma[_p][_m_j] + _ni_packed_idx = _ni_idx // pack_N _up_b_single = ( ( prev_up_w[_k_idx][0][_ni_idx], prev_up_w[_k_idx][1][_ni_idx], ) - if not gate_only + if not _single_b_pipe else None ) compute_bmajor_mfma_phase( @@ -1430,8 +1633,12 @@ def _interleaved_half( ), _up_b_single, _prev_asvs, - _prev_gsv, - _prev_usv if not gate_only else None, + _prev_gsv_list[_ni_packed_idx], + ( + _prev_usv_list[_ni_packed_idx] + if not _single_b_pipe + else None + ), acc_gate, acc_up, _k_idx, @@ -1449,7 +1656,7 @@ def _interleaved_half( cur_a_tile.append(_a_all[(_k, _mi)]) cur_gate_w = [] - cur_up_w = None if gate_only else [] + cur_up_w = None if _single_b_pipe else [] for ku in range_constexpr(k_unroll): g_packs0, g_packs1 = [], [] u_packs0, u_packs1 = [], [] @@ -1457,12 +1664,12 @@ def _interleaved_half( g = _b_gate_all[(ku, ni)] g_packs0.append(g[0]) g_packs1.append(g[1]) - if const_expr(not gate_only): + if const_expr(not _single_b_pipe): u = _b_up_all[(ku, ni)] u_packs0.append(u[0]) u_packs1.append(u[1]) cur_gate_w.append((g_packs0, g_packs1)) - if const_expr(not gate_only): + if const_expr(not _single_b_pipe): cur_up_w.append((u_packs0, u_packs1)) cur_a_scale = [] @@ -1473,9 +1680,21 @@ def _interleaved_half( [_new_as_list[_mi_p]], ) ) - cur_gate_bs = [vector.from_elements(T.vec(1, T.i32), [_new_gs])] - if const_expr(not gate_only): - cur_up_bs = [vector.from_elements(T.vec(1, T.i32), [_new_us])] + cur_gate_bs = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + cur_gate_bs.append( + vector.from_elements( + T.vec(1, T.i32), [_new_gs_list[_gs_ni]] + ) + ) + if const_expr(not _single_b_pipe): + cur_up_bs = [] + for _us_ni in range_constexpr(num_acc_n_packed): + cur_up_bs.append( + vector.from_elements( + T.vec(1, T.i32), [_new_us_list[_us_ni]] + ) + ) else: cur_up_bs = None @@ -1520,7 +1739,9 @@ def _interleaved_half( scf.YieldOp([]) acc_gate = [acc_init] * num_acc_n * m_repeat - acc_up = [acc_init] * num_acc_n * m_repeat if not gate_only else None + acc_up = ( + [acc_init] * num_acc_n * m_repeat if not _single_b_pipe else None + ) _k1 = k_base_idx + arith.constant(tile_k, index=True) rocdl.sched_barrier(0) @@ -1533,7 +1754,8 @@ def _interleaved_half( _k0_b = k_base_idx // arith.constant(2, index=True) gate_w0, up_w0 = load_b_tile(_k0_b) # Prime the deep pipeline: DMA K=tile_k -> ping (1 tile ahead) - # rocdl.s_waitcnt(8) + if const_expr(use_async_copy): + rocdl.s_waitcnt(0) gpu.barrier() rocdl.sched_barrier(0) a_tile_pong = prefetch_full_a_from_lds(lds_x_pong) @@ -1547,99 +1769,10 @@ def _interleaved_half( k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) if const_expr(k_main2_py < 0): k_main2_py = 0 + gate_w_pong = gate_w0 up_w_pong = up_w0 - def _sched_hints_stage1_gate_up(): - """Stage1 hot-loop scheduler adapted from the gate/up gufusion pipeline. - - The original hot loop doubles the B-side VMEM and MFMA streams: - - gate B load + up B load - - gate B-scale load + up B-scale load - - gate MFMA + up MFMA - - The scheduler API here is less expressive than the original - `__builtin_amdgcn_sched_group_barrier`, so we encode the same - idea with a compact heuristic: - - always double MFMA groups (`num_acc_n * 2`) - - use 2 VMEM groups only when the N tile is wide enough to - sustain the extra B-side traffic (`num_acc_n >= 4`) - - otherwise keep 1 VMEM group to avoid over-throttling the - smaller `tile_n=128` kernels - """ - # mfma_group = num_acc_n * 2 - # mfma_total = (k_unroll * 2) * m_repeat * mfma_group - # mfma_per_iter = 2 * mfma_group - # sche_iters = ( - # 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - # ) - - # # Approximate the doubled B-side prefetch pressure. - # vmem_groups = 2 if int(num_acc_n) >= 4 else 1 - - # rocdl.sched_dsrd(2) - # rocdl.sched_mfma(2) - # rocdl.sched_dsrd(1) - # rocdl.sched_mfma(1) - # rocdl.sched_dsrd(1) - # rocdl.sched_mfma(1) - - # dswr_tail = num_x_loads - # if dswr_tail > sche_iters: - # dswr_tail = sche_iters - # dswr_start = sche_iters - dswr_tail - - # for sche_i in range_constexpr(sche_iters): - # rocdl.sched_vmem(vmem_groups) - # rocdl.sched_mfma(mfma_group) - # rocdl.sched_dsrd(1) - # rocdl.sched_mfma(mfma_group) - # if sche_i >= dswr_start - 1: - # rocdl.sched_dswr(1) - # rocdl.sched_barrier(0) - - if const_expr(use_async_copy): - a_vmem_load = max(1, tile_m // 32) - mfma_group = a_vmem_load - rocdl.sched_vmem(a_vmem_load) - - rocdl.sched_mfma(mfma_group) - - b_vmem_total = k_unroll * num_acc_n * 2 - vmem_count = b_vmem_total + 2 + a_vmem_load - - if const_expr(tile_m == 16): - for i in range_constexpr(2): - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - for i in range_constexpr(9): - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - else: - for i in range_constexpr(a_vmem_load * 4): - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - - if const_expr(tile_m == 32): - for i in range_constexpr(vmem_count - a_vmem_load * 4): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - elif const_expr(tile_m == 64): - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - rocdl.sched_vmem(1) - rocdl.sched_mfma(2) - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - rocdl.sched_vmem(1) - rocdl.sched_mfma(2) - - rocdl.sched_barrier(0) - rocdl.sched_barrier(0) if const_expr(k_main2_py > 0): @@ -1720,6 +1853,7 @@ def _sched_hints_stage1_gate_up(): gate_bs_pong, up_bs_pong, prefetch_epilogue=True, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, ) else: _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) @@ -1729,12 +1863,22 @@ def _sched_hints_stage1_gate_up(): prefetch_x_to_lds(k_tail1, lds_x_ping) else: x_regs_ping = load_x_tile(k_tail1) - gate_w_ping, up_w_ping = load_b_tile( - k_tail1 // arith.constant(2, index=True) - ) - a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // arith.constant(pack_K * 128, index=True) - ) + if _pad_ku_skip > 0: + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True), + ku_limit=_tail_ku, + ) + a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( + k_tail1 // arith.constant(pack_K * 128, index=True), + ku_packed_limit=_tail_ku_packed, + ) + else: + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True) + ) + a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( + k_tail1 // arith.constant(pack_K * 128, index=True) + ) acc_gate, acc_up, _ = compute_tile( acc_gate, acc_up, @@ -1749,7 +1893,12 @@ def _sched_hints_stage1_gate_up(): store_x_tile_to_lds(x_regs_ping, lds_x_ping) rocdl.s_waitcnt(0) _barrier() - a_tile_ping = prefetch_full_a_from_lds(lds_x_ping) + if _pad_ku_skip > 0: + a_tile_ping = prefetch_full_a_from_lds( + lds_x_ping, ku_limit=_tail_ku + ) + else: + a_tile_ping = prefetch_full_a_from_lds(lds_x_ping) acc_gate, acc_up, epilogue_pf = compute_tile( acc_gate, acc_up, @@ -1760,10 +1909,24 @@ def _sched_hints_stage1_gate_up(): gate_bs_ping, up_bs_ping, prefetch_epilogue=True, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, ) - # silu(gate) * up in f32 before epilogue - # silu(x) = x * sigmoid(x); use HW fast path: exp2, rcp + bias_pf = None + if const_expr(epilogue_pf is not None): + _, _, bias_pf = epilogue_pf + + # Activation helpers (f32 element-wise on vec4_f32) + def _silu_elem(g): + """silu(x) = x * sigmoid(x); HW fast path: exp2, rcp""" + neg_log2e = arith.constant(-1.4426950408889634, type=f32) + t = g * neg_log2e + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + one = arith.constant(1.0, type=f32) + den = one + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return g * sig + def _silu_mul_vec4(gate_v4, up_v4): """Element-wise silu(gate) * up on vec4_f32.""" result_elems = [] @@ -1774,20 +1937,128 @@ def _silu_mul_vec4(gate_v4, up_v4): u = vector.extract( up_v4, static_position=[ei], dynamic_position=[] ) - neg_log2e = arith.constant(-1.4426950408889634, type=f32) - t = g * neg_log2e + result_elems.append(_silu_elem(g) * u) + return vector.from_elements(vec4_f32, result_elems) + + def _swiglu_mul_vec4(gate_v4, up_v4): + """Element-wise swiglu(gate, up) on vec4_f32. + swiglu(g, u) = g * sigmoid(alpha * g) * (u + 1) + with clamping: gate <= limit, -limit <= up <= limit. + """ + result_elems = [] + _alpha = arith.constant(1.702, type=f32) + _limit = arith.constant(7.0, type=f32) + _neg_limit = arith.constant(-7.0, type=f32) + _one = arith.constant(1.0, type=f32) + _neg_log2e = arith.constant(-1.4426950408889634, type=f32) + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] + ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + g = arith.minimumf(g, _limit) + u = arith.minimumf(u, _limit) + u = arith.maximumf(u, _neg_limit) + t = g * _alpha * _neg_log2e emu = llvm.call_intrinsic( f32, "llvm.amdgcn.exp2.f32", [t], [], [] ) - one = arith.constant(1.0, type=f32) - den = one + emu + den = _one + emu sig = llvm.call_intrinsic( f32, "llvm.amdgcn.rcp.f32", [den], [], [] ) - result_elems.append(g * sig * u) + result_elems.append(g * sig * (u + _one)) return vector.from_elements(vec4_f32, result_elems) - if const_expr(not _is_splitk): + def _act_vec4(gate_v4, up_v4): + """Dispatch activation based on `act` parameter.""" + if act == "swiglu": + return _swiglu_mul_vec4(gate_v4, up_v4) + else: + return _silu_mul_vec4(gate_v4, up_v4) + + # Add bias to raw GEMM accumulators before activation. + # bias layout: [E, 2*inter_dim] flat f32 (non-interleaved: gate then up). + # For gate_up_interleave, map physical column to logical bias offset. + if const_expr(enable_bias and not _is_splitk): + if const_expr(bias_pf is not None): + _bias_gate_vals = bias_pf + else: + _bias_gate_vals = [] + for _ni in range_constexpr(num_acc_n): + if const_expr(gate_up_interleave): + _logical_col = ( + (by_n + n_tile_base) + // arith.constant(2, index=True) + + arith.constant((_ni // 2) * 16, index=True) + + lane_mod_16 + ) + _up_off = ( + inter_idx + if (_ni % 2 == 1) + else arith.constant(0, index=True) + ) + _bias_off = expert_off_idx + _up_off + _logical_col + else: + _bn = ( + by_n + + n_tile_base + + arith.constant(_ni * 16, index=True) + + lane_mod_16 + ) + _bias_off = expert_off_idx + _bn + _bias_gate_vals.append( + buffer_ops.buffer_load( + bias_rsrc, _bias_off, vec_width=1, dtype=f32 + ) + ) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, [_bias_gate_vals[_ni]] * 4 + ) + acc_gate[_aidx] = arith.addf(acc_gate[_aidx], _bsplat) + + if const_expr(not (mock_gate_only or gate_up_interleave)): + _bias_up_vals = [] + for _ni in range_constexpr(num_acc_n): + _bn = ( + by_n + + n_tile_base + + arith.constant(_ni * 16, index=True) + + lane_mod_16 + ) + _bias_up_vals.append( + buffer_ops.buffer_load( + bias_rsrc, + expert_off_idx + inter_idx + _bn, + vec_width=1, + dtype=f32, + ) + ) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, [_bias_up_vals[_ni]] * 4 + ) + acc_up[_aidx] = arith.addf(acc_up[_aidx], _bsplat) + + if const_expr(gate_up_interleave and not _is_splitk): + _gui_out_n = num_acc_n // pack_N + acc = [None] * (_gui_out_n * m_repeat) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(_gui_out_n): + _g_idx = _mi * num_acc_n + _ni * pack_N + _u_idx = _g_idx + 1 + _out_idx = _mi * _gui_out_n + _ni + acc[_out_idx] = _act_vec4( + acc_gate[_g_idx], acc_gate[_u_idx] + ) + elif const_expr(not _is_splitk): acc = [None] * (int(num_acc_n) * int(m_repeat)) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(num_acc_n): @@ -1798,8 +2069,9 @@ def _silu_mul_vec4(gate_v4, up_v4): # Output: out[(t*topk+s) * inter_dim + col] = silu(gate) * up # For split-K: skip silu, output gate/up separately with atomic add tw_pf = None + bias_pf = None if const_expr(epilogue_pf is not None): - _, tw_pf, _ = epilogue_pf + _, tw_pf, bias_pf = epilogue_pf mask24_i32 = arith.constant(0xFFFFFF) topk_i32_v = topk_i32 @@ -1829,7 +2101,6 @@ def write_row_to_lds( col_base_local, num_acc_n: int, lds_out, - acc_v, ): if const_expr(_apply_weight): tw_idx = (mi * 4) + ii @@ -1843,7 +2114,7 @@ def write_row_to_lds( col_local = col_base_local + (ni * 16) acc_idx = mi * num_acc_n + ni v = vector.extract( - acc_v[acc_idx], static_position=[ii], dynamic_position=[] + acc[acc_idx], static_position=[ii], dynamic_position=[] ) if const_expr(_apply_weight): v = v * tw @@ -1862,7 +2133,11 @@ def write_row_to_lds( _out_row_stride = ( inter_dim * 2 * out_elem_bytes if _is_splitk - else (inter_dim // 2 if _need_quant else inter_dim * out_elem_bytes) + else ( + inter_dim // 2 + if _need_fp4 + else (inter_dim if _need_fp8 else inter_dim * out_elem_bytes) + ) ) def precompute_row(*, row_local, row): @@ -1889,32 +2164,6 @@ def _idx_to_llvm_ptr(idx_val, addr_space=1): ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") return llvm.inttoptr(ptr_ty, i64_raw) - def _make_write_row_to_lds(acc_v): - def _write_row_to_lds_bound( - *, - mi: int, - ii: int, - row_in_tile, - row, - row_base_lds, - col_base_local, - num_acc_n: int, - lds_out, - ): - return write_row_to_lds( - mi=mi, - ii=ii, - row_in_tile=row_in_tile, - row=row, - row_base_lds=row_base_lds, - col_base_local=col_base_local, - num_acc_n=num_acc_n, - lds_out=lds_out, - acc_v=acc_v, - ) - - return _write_row_to_lds_bound - _e_vec = _e_vec_s1 _e_vec_sk = 2 _cshuffle_nlane = min(32, tile_n // _e_vec) @@ -1947,6 +2196,10 @@ def _write_row_to_lds_bound( _c0x80000000_i32 = arith.constant(0x80000000, type=T.i32) _c0_f32 = arith.constant(0.0, type=T.f32) + _c8_i32 = arith.constant(8, type=T.i32) + _fp_headroom = 2 if _need_fp4 else (8 if _need_fp8 else 0) + _c_headroom_i32 = arith.constant(_fp_headroom, type=T.i32) + def _f32_to_e2m1(qx_f32): """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" qx = qx_f32.bitcast(T.i32) @@ -1995,57 +2248,136 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): max_i32 = local_max.bitcast(T.i32) max_rounded = (max_i32 + _c0x200000_i32) & _c0xFF800000_i32 exp_field = max_rounded >> _c23_i32 - e8m0_biased = arith.maxsi(exp_field - _c2_i32, _c0_i32) + e8m0_biased = arith.maxsi(exp_field - _c_headroom_i32, _c0_i32) quant_exp = _c254_i32 - e8m0_biased quant_scale = (quant_exp << _c23_i32).bitcast(T.f32) - fp4_vals = [] - for i in range_constexpr(_e_vec): - scaled_v = frag_vals[i] * quant_scale - fp4_vals.append(_f32_to_e2m1(scaled_v)) - - packed_i32 = fp4_vals[0] | (fp4_vals[1] << _c4_i32) - for k in range_constexpr(1, _e_vec // 2): - byte_k = fp4_vals[2 * k] | (fp4_vals[2 * k + 1] << _c4_i32) - packed_i32 = packed_i32 | ( - byte_k << arith.constant(k * 8, type=T.i32) - ) + if const_expr(_need_fp4): + fp4_vals = [] + for i in range_constexpr(_e_vec): + scaled_v = frag_vals[i] * quant_scale + fp4_vals.append(_f32_to_e2m1(scaled_v)) - ptr_addr_idx = row_byte_base + col_g0 / arith.constant( - 2, index=True - ) - out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) - _pack_bytes = _e_vec // 2 - if const_expr(_pack_bytes == 1): - store_val = arith.TruncIOp(T.i8, packed_i32) - store_raw = ( - store_val._value - if hasattr(store_val, "_value") - else store_val - ) - llvm.StoreOp( - store_raw, out_ptr_v, alignment=1, nontemporal=True - ) - elif const_expr(_pack_bytes == 2): - store_val = arith.TruncIOp(T.i16, packed_i32) - store_raw = ( - store_val._value - if hasattr(store_val, "_value") - else store_val - ) - llvm.StoreOp( - store_raw, out_ptr_v, alignment=2, nontemporal=True - ) - else: - packed_raw = ( - packed_i32._value - if hasattr(packed_i32, "_value") - else packed_i32 - ) - llvm.StoreOp( - packed_raw, out_ptr_v, alignment=4, nontemporal=True + packed_i32 = fp4_vals[0] | (fp4_vals[1] << _c4_i32) + for k in range_constexpr(1, _e_vec // 2): + byte_k = fp4_vals[2 * k] | ( + fp4_vals[2 * k + 1] << _c4_i32 + ) + packed_i32 = packed_i32 | ( + byte_k << arith.constant(k * 8, type=T.i32) + ) + + ptr_addr_idx = row_byte_base + col_g0 / arith.constant( + 2, index=True ) + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + _pack_bytes = _e_vec // 2 + if const_expr(_pack_bytes == 1): + store_val = arith.TruncIOp(T.i8, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, out_ptr_v, alignment=1, nontemporal=True + ) + elif const_expr(_pack_bytes == 2): + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, out_ptr_v, alignment=2, nontemporal=True + ) + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 + ) + llvm.StoreOp( + packed_raw, out_ptr_v, alignment=4, nontemporal=True + ) + + elif const_expr(_need_fp8): + scaled_vals = [] + for i in range_constexpr(_e_vec): + scaled_vals.append(frag_vals[i] * quant_scale) + + ptr_addr_idx = row_byte_base + col_g0 + if const_expr(_e_vec <= 4): + packed_i32 = _c0_i32 + for _w in range_constexpr(_e_vec // 2): + packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[2 * _w], + scaled_vals[2 * _w + 1], + packed_i32, + _w, + ) + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + if _e_vec == 2: + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, + out_ptr_v, + alignment=2, + nontemporal=True, + ) + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 + ) + llvm.StoreOp( + packed_raw, + out_ptr_v, + alignment=4, + nontemporal=True, + ) + else: + for _wg in range_constexpr(_e_vec // 4): + _b = _wg * 4 + packed_w = _c0_i32 + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[_b], + scaled_vals[_b + 1], + packed_w, + 0, + ) + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[_b + 2], + scaled_vals[_b + 3], + packed_w, + 1, + ) + word_ptr = ptr_addr_idx + arith.constant( + _wg * 4, index=True + ) + out_ptr_v = _idx_to_llvm_ptr(word_ptr) + packed_raw = ( + packed_w._value + if hasattr(packed_w, "_value") + else packed_w + ) + llvm.StoreOp( + packed_raw, + out_ptr_v, + alignment=4, + nontemporal=True, + ) if const_expr(_need_sort): col_g0_i32 = arith.index_cast(T.i32, col_g0) @@ -2115,8 +2447,40 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): else (ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()) ) - if const_expr(gate_only): - # gate_only: single pass, by_n covers full [0, 2*inter_dim) + if const_expr(gate_up_interleave and not _is_splitk): + # gui without splitk: acc has activation applied, halved N + _gui_eff_n = _gui_out_n + _gui_tile_n = tile_n // 2 + _gui_cshuffle_nlane = min(32, _gui_tile_n // _e_vec) + _gui_by_n = by_n / arith.constant(2, index=True) + _gui_n_tile_base = n_tile_base / arith.constant(2, index=True) + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=_gui_tile_n, + e_vec=_e_vec, + cshuffle_nlane=_gui_cshuffle_nlane, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=_gui_eff_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=_gui_by_n, + n_tile_base=_gui_n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + elif const_expr(mock_gate_only or (gate_up_interleave and _is_splitk)): + # mock_gate_only: single pass, by_n covers full [0, 2*inter_dim) _eff_e_vec = _e_vec_sk acc = acc_gate c_shuffle_epilog( @@ -2140,9 +2504,10 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=_make_write_row_to_lds(acc), + write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, + lds_out_split=lds_out_B, ) elif const_expr(_is_splitk): # Two-pass epilogue: gate then up, each with atomic add @@ -2172,9 +2537,10 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=_make_write_row_to_lds(acc), + write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, + lds_out_split=lds_out_B, ) gpu.barrier() @@ -2203,9 +2569,10 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=_make_write_row_to_lds(acc), + write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, + lds_out_split=lds_out_B, ) else: c_shuffle_epilog( @@ -2229,9 +2596,10 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=_make_write_row_to_lds(acc), + write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, + lds_out_split=lds_out_B, ) _if_blk = scf.IfOp(blk_valid) @@ -2262,12 +2630,12 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): inter_dim_pad, use_cshuffle_epilog, persist_m, - fuse_fp4_quant, - fuse_sort_scale, use_async_copy, waves_per_eu, k_batch, - gate_only, + gate_mode, + a_scale_one, + xcd_swizzle, ) @flyc.jit @@ -2298,13 +2666,15 @@ def launch_mixed_moe_gemm1( allocator_ping.finalize() inter_in = arith.index_cast(ir.IndexType.get(), i32_inter_in.ir_value()) - if const_expr(gate_only): - gx = inter_in / arith.constant(tile_n, index=True) + tile_n_index = arith.constant(tile_n, index=True) + inter_dim_pad_total = arith.constant(2 * inter_dim_pad, index=True) + if const_expr(mock_gate_only or gate_up_interleave): + gx = (inter_in - inter_dim_pad_total + tile_n_index - 1) / tile_n_index else: gx = ( - inter_in + (inter_in - inter_dim_pad_total + 2 * tile_n_index - 1) + / tile_n_index / arith.constant(2, index=True) - / arith.constant(tile_n, index=True) ) _c_pm_l = arith.constant(persist_m, index=True) gy = ( @@ -2358,6 +2728,8 @@ def compile_mixed_moe_gemm2( inter_dim_pad: int = 0, persist_m: int = 4, sort_block_m: int = 0, + b_nt: int = 2, + xcd_swizzle: int = 0, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. @@ -2397,7 +2769,9 @@ def compile_mixed_moe_gemm2( ) gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + _state = {} if a_dtype not in ("fp8", "fp16", "int8", "fp4"): raise ValueError( @@ -2514,6 +2888,13 @@ def _scale_elem_type(): pad_k = 0 if _use_lds128 else 8 lds_stride = tile_k + pad_k + if a_elem_vec_pack > 1: + _eff_lds_stride = lds_stride // a_elem_vec_pack + _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack + else: + _eff_lds_stride = lds_stride + _eff_tile_k_bytes = tile_k_bytes + if out_is_f32: # Match origin/dev_a16w4: f32 output uses scalar atomics and does NOT use the CShuffle epilogue. _use_cshuffle_epilog = ( @@ -2559,29 +2940,37 @@ def out_elem(): _cu_num = 0 _sbm_tag = "" if _sort_block_m == tile_m else f"_sbm{_sort_block_m}" _pm_tag = f"_persist_cu{_cu_num}" if _persistent else f"_pm{persist_m}" + _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" module_name = ( f"mfma_moe2_a{a_dtype}_w{b_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_vscale_fix3{_pm_tag}{_sbm_tag}" + f"_vscale_fix3{_pm_tag}{_sbm_tag}{_xcd_tag}" ).replace("-", "_") # -- LDS sizing (pure Python; no MLIR Context needed) --------------------- - # Reuse a single allocation for both: - # - ping-pong A2 tiles (2 * tile_m * lds_stride * elem_bytes bytes) - # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) + # Ping-pong A2 tiles via separate allocators (like stage1). + _single_x_bytes = int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + _cshuffle_elem_bytes_s2 = 2 # f16/bf16 = 2 bytes lds_out_bytes = ( - 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - ) # f16 bytes + _cshuffle_elem_bytes_s2 * int(tile_m) * int(tile_n) + if _use_cshuffle_epilog + else 0 + ) lds_tid_bytes = int(tile_m) * 4 - lds_total_bytes = max(lds_x_bytes, lds_out_bytes) + lds_tid_bytes - lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2) + _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) + + _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) + _ping_buffer_bytes = _single_x_bytes def x_lds_elem(): return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - lds_alloc_bytes = int(lds_total_elems) * int(a_elem_bytes) - lds_alloc_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = lds_alloc_offset + lds_alloc_bytes + lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = lds_pong_offset + _pong_buffer_bytes + _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) + allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes + + lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = lds_ping_offset + _ping_buffer_bytes if True: @@ -2606,9 +2995,10 @@ def moe_gemm2( tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) - size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + size_expert_ids_in = arith.index_cast( + ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + ) x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. f32 = T.f32 i32 = T.i32 i64 = T.i64 @@ -2651,45 +3041,71 @@ def check_c_k_valid_gate(base_k): arith, c_mn=c_n_total, c_k=c_k_orig ) - shape_lds = fx.make_shape(tile_m, tile_k) - stride_lds = fx.make_stride(lds_stride, 1) + shape_lds = fx.make_shape(tile_m, _eff_lds_stride) + stride_lds = fx.make_stride(_eff_lds_stride, 1) layout_lds = fx.make_layout(shape_lds, stride_lds) tx = gpu.thread_id("x") by = gpu.block_id("x") # tile along model_dim (N-dim) bx_persist = gpu.block_id("y") # persistent WG index (M-dim) + if const_expr(xcd_swizzle > 0): + _NUM_XCDS_S = 8 + _c1_sw = arith.constant(1, index=True) + _c_tn_sw = arith.constant(tile_n, index=True) + _c_mdp_sw = arith.constant(model_dim_pad, index=True) + _gx = (n_in - _c_mdp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw + if const_expr(_persistent): + _gy = arith.constant(_cu_num, index=True) + else: + _c_pm_sw = arith.constant(persist_m, index=True) + _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw + + _linear_id = bx_persist * _gx + by + _num_wgs = _gx * _gy + + _c_xcds = arith.constant(_NUM_XCDS_S, index=True) + _wgs_per_xcd = _num_wgs / _c_xcds + _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) + + _WGM_S = xcd_swizzle + _c_wgm = arith.constant(_WGM_S, index=True) + _num_wgid_in_group = _c_wgm * _gx + _group_id = _wgid / _num_wgid_in_group + _first_pid_m = _group_id * _c_wgm + _remaining_m = _gy - _first_pid_m + _cmp_m = arith.cmpi(CmpIPredicate.ult, _remaining_m, _c_wgm) + _group_size_m = arith.select(_cmp_m, _remaining_m, _c_wgm) + + _wgid_in_group = _wgid % _num_wgid_in_group + bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) + by = _wgid_in_group / _group_size_m + # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). - k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) + k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) layout_tx_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) - base_ptr = allocator.get_base() - lds_x_ptr = SmemPtr( - base_ptr, - lds_alloc_offset, - x_lds_elem(), - shape=(lds_total_elems,), - ) - lds_x = lds_x_ptr.get() - # Alias the same underlying LDS bytes as f16/bf16 for epilogue shuffle. + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + lds_x_pong = SmemPtr( + base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_input_elems,) + ).get() + lds_x_ping = SmemPtr( + base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_input_elems,) + ).get() lds_out = ( SmemPtr( - base_ptr, - lds_x_ptr.byte_offset, + base_ptr_pong, + lds_pong_offset, (T.bf16 if out_is_bf16 else T.f16), shape=(tile_m * tile_n,), ).get() if _use_cshuffle_epilog else None ) - - # lds_tid: alias LDS after max(x, out) for sorted_idx preload - _lds_x_b = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) - _lds_out_b = 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - _lds_tid_off = max(_lds_x_b, _lds_out_b) lds_tid = SmemPtr( - base_ptr, lds_x_ptr.byte_offset + _lds_tid_off, T.i32, shape=(tile_m,) + base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) ).get() # Buffer resources. @@ -2748,7 +3164,7 @@ def check_c_k_valid_gate(base_k): sx_rsrc = 1 sw_rsrc = 1 if const_expr(not is_f16_a): - if const_expr(is_f4_a): + if const_expr(is_f4_a or is_f8_a): # A2 microscale: e8m0 in sorted layout [sorted_size, K/32]. # Caller must pre-scatter a2_scale via moe_mxfp4_sort. kblk = _div_pow2(k_in, 32) @@ -2861,7 +3277,7 @@ def check_c_k_valid_gate(base_k): expert_i32 = buffer_ops.buffer_load( expert_rsrc, sort_blk, vec_width=1, dtype=T.i32 ) - expert_idx = arith.index_cast(T.index, expert_i32) + expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) exp_valid = arith.cmpi( CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) ) @@ -3017,7 +3433,7 @@ def load_x(idx_i32): t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) s_safe = arith.select(ts_valid, s_i32, arith.constant(0)) row_ts_i32 = t_safe * topk_i32 + s_safe - row_ts_idx = arith.index_cast(T.index, row_ts_i32) + row_ts_idx = arith.index_cast(ir.IndexType.get(), row_ts_i32) x_row_base_div4.append(row_ts_idx * c_k_div4) @@ -3072,8 +3488,13 @@ def load_x_tile(base_k): _n_scale_shift_i32 = None n_intra_list = [None] * num_acc_n n_blk_list = [None] * num_acc_n + col_g_list = [None] * num_acc_n for i in range_constexpr(num_acc_n): offset = i * 16 + col_g = by_n + n_tile_base + col_g = _div_pow2(col_g, 2) + offset + col_g = col_g + lane_mod_16 + col_g_list[i] = col_g c_offset = arith.constant(offset, index=True) global_n = by_n + n_tile_base + c_offset + lane_mod_16 n_blk_list[i] = _div_pow2(global_n, 16) @@ -3087,6 +3508,16 @@ def load_x_tile(base_k): m_repeat_packed = m_repeat // pack_M num_acc_n_packed = num_acc_n // pack_N + _K_per_ku_s2 = tile_k // k_unroll + _pad_k_elems_s2 = (inter_dim_pad % tile_k) if inter_dim_pad > 0 else 0 + _pad_ku_skip_s2 = _pad_k_elems_s2 // _K_per_ku_s2 + _tail_ku_s2 = k_unroll - _pad_ku_skip_s2 + _tail_ku_packed_s2 = ( + (_tail_ku_s2 + pack_K - 1) // pack_K + if _pad_ku_skip_s2 > 0 + else None + ) + # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_packs_k64(base_k, ku: int, ni: int): """Load one K64-byte B micro-step: single 16B load, split into 2x i64.""" @@ -3119,6 +3550,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): vec_elems=vec_elems, elem_bytes=b_elem_bytes, offset_in_bytes=(b_elem_bytes == 1), + cache_modifier=b_nt, ) b_i64x2 = vector.bitcast(vec2_i64, b16) b0 = vector.extract( @@ -3129,9 +3561,9 @@ def load_b_packs_k64(base_k, ku: int, ni: int): ) return b0, b1 - def load_b_tile(base_k): + def load_b_tile(base_k, ku_limit=k_unroll): b_tile = [] - for ku in range_constexpr(k_unroll): + for ku in range_constexpr(ku_limit): packs0 = [] packs1 = [] for ni in range_constexpr(num_acc_n): @@ -3192,9 +3624,11 @@ def _apply_k_shift(scale_vec, k_shift_bits): return vector.from_elements(T.vec(1, T.i32), [val]) return scale_vec - def load_b_scale_tile(base_k, k_shift_bits=0): + def load_b_scale_tile( + base_k, k_shift_bits=0, ku_packed_limit=k_unroll_packed + ): b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): + for ku in range_constexpr(ku_packed_limit): for ni in range_constexpr(num_acc_n_packed): scale = load_scale( arg_scale_w, @@ -3214,9 +3648,11 @@ def load_b_scale_tile(base_k, k_shift_bits=0): b_scale_tile.append(scale) return b_scale_tile - def load_a_scale_tile(base_k, k_shift_bits=0): + def load_a_scale_tile( + base_k, k_shift_bits=0, ku_packed_limit=k_unroll_packed + ): a_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): + for ku in range_constexpr(ku_packed_limit): for mi in range_constexpr(m_repeat_packed): scale = load_scale( arg_scale_x, @@ -3229,17 +3665,25 @@ def load_a_scale_tile(base_k, k_shift_bits=0): a_scale_tile.append(scale) return a_scale_tile - def prefetch_ab_scale_tile(base_k, k_shift_bits=0): + def prefetch_ab_scale_tile( + base_k, k_shift_bits=0, ku_packed_limit=k_unroll_packed + ): return [ - load_a_scale_tile(base_k, k_shift_bits), - load_b_scale_tile(base_k, k_shift_bits), + load_a_scale_tile( + base_k, k_shift_bits, ku_packed_limit=ku_packed_limit + ), + load_b_scale_tile( + base_k, k_shift_bits, ku_packed_limit=ku_packed_limit + ), ] vec8_x = T.vec(vec8_elems, x_elem) vec4_x_lds = T.vec(vec4_elems, x_elem) - # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + # ---- Pipeline helpers: store X tile to LDS (unused in DMA path) ---- + _lds_base_zero = arith.index(0) + + def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] @@ -3247,14 +3691,14 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): lds_store_16b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec16_ty=vec16_x, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) @@ -3262,14 +3706,14 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): lds_store_8b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec8_ty=vec8_x, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x2=vec_x_in_parts[i], elem_bytes=elem_bytes, ) @@ -3277,21 +3721,20 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): lds_store_4b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec4_ty=vec4_x_lds, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x1=vec_x_in_parts[i], elem_bytes=elem_bytes, ) # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): - # Swizzle in bytes, then convert to element offset for memref indexing. + def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16( curr_row_a_lds, col_base, k_blocks16 ) @@ -3300,10 +3743,8 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) ) - # Pass as list so layout_utils.crd2idx uses static arith path idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) - idx_a16 = idx_a16 + lds_base - loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) + loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract( a_i64x2, static_position=[0], dynamic_position=[] @@ -3316,7 +3757,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): def compute_tile( acc_in, b_tile_in, - lds_base, + lds_buffer, a_scale=None, b_scale=None, *, @@ -3324,6 +3765,7 @@ def compute_tile( a0_prefetch=None, a1_prefetch=None, b_hi_loader=None, + ku_count=k_unroll, ): if const_expr(b_hi_loader is not None): b_tile_full = [None] * k_unroll @@ -3391,7 +3833,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for _bhi_i in range_constexpr(len(_b_hi)): b_tile_full[_b_split_ku + _bhi_i] = _b_hi[_bhi_i] - for k_idx in range_constexpr(k_unroll): + for k_idx in range_constexpr(ku_count): ku128 = k_idx >> _pack_K_shift ikxdl = k_idx & _pack_K_mask @@ -3440,13 +3882,13 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): a0, a1 = a1_prefetch else: a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base0, lds_base + curr_row_a_lds, col_base0, lds_buffer ) if const_expr(is_f8_a): col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64( - curr_row_a_lds, col_base1, lds_base + curr_row_a_lds, col_base1, lds_buffer ) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) else: @@ -3464,7 +3906,6 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): ) acc_idx = mi_idx * num_acc_n + ni_idx - rocdl.sched_barrier(0) acc_list[acc_idx] = ( rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, @@ -3485,55 +3926,68 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): return acc_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- - lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) - lds_base_cur = arith.index(0) - lds_base_nxt = lds_tile_elems + # ---- Async DMA: GMEM -> LDS (bypasses VGPR, like stage1) ---- + _dma_bytes = 16 + _wave_size = 64 + _eff_bytes_per_buffer = ( + int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + ) + _num_dma_loads = max( + 1, _eff_bytes_per_buffer // (total_threads * _dma_bytes) + ) + + def dma_x_tile_to_lds(base_k, lds_buffer): + c4_idx = arith.index(4) + base_k_div4 = _div_pow2( + _div_pow2(base_k, int(a_elem_vec_pack)) + * arith.constant(int(a_elem_bytes), index=True), + 4, + ) + + lds_ptr_i64 = None + for i in range_constexpr(_num_dma_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(T.i32, global_byte_idx) + + if const_expr(i == 0): + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=T.i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) rocdl.sched_barrier(0) def hot_loop_scheduler(): - # - MFMA group size per "slot": num_acc_n - # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n - # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration. - mfma_group = num_acc_n - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = ( - 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - ) - - rocdl.sched_dsrd(2) - rocdl.sched_mfma(1) - if const_expr(tile_m == 16): - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - if const_expr(tile_m == 16): - rocdl.sched_vmem(1) - if const_expr(num_acc_n < 4): - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if const_expr(tile_m == 16): - rocdl.sched_vmem(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if const_expr(tile_m == 16): - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - - # DS-write hints near the end: match total A LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if const_expr(dswr_tail > sche_iters): - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if const_expr(sche_i >= dswr_start - 1): - rocdl.sched_dswr(1) - rocdl.sched_barrier(0) def _k_shift_bits(k_py): @@ -3560,7 +4014,7 @@ def _k_base(k_py): gpu.barrier() - # Prologue -- B-first. + # Prologue -- B-first + async DMA X(0) -> pong. k0 = arith.index(0) if const_expr(_b_split_enabled): b_cur = load_b_tile_lo(k0) @@ -3569,26 +4023,20 @@ def _k_base(k_py): a_scale_pong, b_scale_pong = prefetch_ab_scale_tile( _k_base(0), _k_shift_bits(0) ) - # scheduling fence to prevent LLVM from deferring - # the scale buffer_loads past the upcoming barrier. rocdl.sched_barrier(0) - x_regs0 = load_x_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + prefetch_x_to_lds(k0, lds_x_pong) + rocdl.s_waitcnt(0) gpu.barrier() acc = [acc_init] * num_acc_n * m_repeat - lds_base_pong = lds_base_cur - lds_base_ping = lds_base_nxt - # Cross-tile A0+A1 LDS prefetch: issue both ds_reads back-to-back - # (??2) so LDS bandwidth is fully utilized and the second read - # completes during MFMA #1/#2 execution. + # Cross-tile A0+A1 LDS prefetch from pong buffer. a0_prefetch_pong = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_pong + row_a_lds, col_offset_base, lds_x_pong ) _a1_col_base = col_offset_base + 128 // a_elem_vec_pack a1_prefetch_pong = ( - lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_pong) + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_pong) if pack_K >= 2 else None ) @@ -3619,10 +4067,12 @@ def _make_b_hi_loader(base_k): if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): + rocdl.sched_barrier(0) k_iv = arith.index(k_iv_py) next_k1 = k_iv + tile_k next_k1_bk = next_k1 // 2 - x_regs_ping = load_x_tile(next_k1) + # DMA X(next_k1) -> ping (non-blocking, overlaps with compute) + prefetch_x_to_lds(next_k1, lds_x_ping) b_ping_lo = ( load_b_tile_lo(next_k1_bk) if _b_split_enabled @@ -3635,7 +4085,7 @@ def _make_b_hi_loader(base_k): acc, _ = compute_tile( acc, b_pong, - lds_base_pong, + lds_x_pong, a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, @@ -3646,16 +4096,16 @@ def _make_b_hi_loader(base_k): else None ), ) - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - # hot_loop_scheduler() + hot_loop_scheduler() + rocdl.s_waitcnt(0) gpu.barrier() # Cross-tile prefetch for the ping tile we are about to compute. a0_prefetch_ping = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_ping + row_a_lds, col_offset_base, lds_x_ping ) a1_prefetch_ping = ( - lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_ping) + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_ping) if pack_K >= 2 else None ) @@ -3663,7 +4113,8 @@ def _make_b_hi_loader(base_k): next_k2 = k_iv + c2_tile_k next_k2_py = k_iv_py + tile_k * 2 next_k2_bk = next_k2 // 2 - x_regs_pong = load_x_tile(next_k2) + # DMA X(next_k2) -> pong (non-blocking, overlaps with compute) + prefetch_x_to_lds(next_k2, lds_x_pong) b_pong = ( load_b_tile_lo(next_k2_bk) if _b_split_enabled @@ -3676,7 +4127,7 @@ def _make_b_hi_loader(base_k): acc, _ = compute_tile( acc, b_ping_lo, - lds_base_ping, + lds_x_ping, a_scale_ping, b_scale_ping, a0_prefetch=a0_prefetch_ping, @@ -3688,26 +4139,25 @@ def _make_b_hi_loader(base_k): ), ) k0_pong_bk = next_k2_bk - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - # hot_loop_scheduler() + hot_loop_scheduler() gpu.barrier() # Cross-tile prefetch for the next pong tile. a0_prefetch_pong = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_pong + row_a_lds, col_offset_base, lds_x_pong ) a1_prefetch_pong = ( - lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_pong) + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_pong) if pack_K >= 2 else None ) if const_expr(odd_k_tiles): - # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). + # Tail: single remaining tile (already in pong buffer). acc, epilogue_pf = compute_tile( acc, b_pong, - lds_base_pong, + lds_x_pong, a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, @@ -3716,6 +4166,7 @@ def _make_b_hi_loader(base_k): b_hi_loader=( _make_b_hi_loader(k0_pong_bk) if _b_split_enabled else None ), + ku_count=_tail_ku_s2 if _pad_ku_skip_s2 > 0 else k_unroll, ) else: @@ -3725,20 +4176,29 @@ def _make_b_hi_loader(base_k): int(inter_dim) + tile_k - 1 ) // tile_k * tile_k - tile_k k_tail1_bk = k_tail1 // 2 - x_regs_ping = load_x_tile(k_tail1) - b_ping_lo = ( - load_b_tile_lo(k_tail1_bk) - if _b_split_enabled - else load_b_tile(k_tail1_bk) - ) - a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( - _k_base(k_tail1_py), _k_shift_bits(k_tail1_py) - ) + # DMA tail X -> ping + prefetch_x_to_lds(k_tail1, lds_x_ping) + if const_expr(_pad_ku_skip_s2 > 0): + b_ping_lo = load_b_tile(k_tail1_bk, ku_limit=_tail_ku_s2) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( + _k_base(k_tail1_py), + _k_shift_bits(k_tail1_py), + ku_packed_limit=_tail_ku_packed_s2, + ) + else: + b_ping_lo = ( + load_b_tile_lo(k_tail1_bk) + if _b_split_enabled + else load_b_tile(k_tail1_bk) + ) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( + _k_base(k_tail1_py), _k_shift_bits(k_tail1_py) + ) acc, _ = compute_tile( acc, b_pong, - lds_base_pong, + lds_x_pong, a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, @@ -3748,40 +4208,48 @@ def _make_b_hi_loader(base_k): ), ) - store_x_tile_to_lds(x_regs_ping, lds_base_ping) # hot_loop_scheduler() + rocdl.s_waitcnt(0) gpu.barrier() # Epilogue tile with sw prefetch. a0_prefetch_ping = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_ping + row_a_lds, col_offset_base, lds_x_ping ) a1_prefetch_ping = ( - lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_ping) - if pack_K >= 2 + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_ping) + if pack_K >= 2 and (_pad_ku_skip_s2 == 0 or _tail_ku_s2 >= 2) else None ) acc, epilogue_pf = compute_tile( acc, b_ping_lo, - lds_base_ping, + lds_x_ping, a_scale_ping, b_scale_ping, a0_prefetch=a0_prefetch_ping, a1_prefetch=a1_prefetch_ping, prefetch_epilogue=True, b_hi_loader=( - _make_b_hi_loader(k_tail1_bk) if _b_split_enabled else None + None + if _pad_ku_skip_s2 > 0 + else ( + _make_b_hi_loader(k_tail1_bk) + if _b_split_enabled + else None + ) ), + ku_count=_tail_ku_s2 if _pad_ku_skip_s2 > 0 else k_unroll, ) # ---------------- Epilogue: LDS CShuffle + atomic half2 (x2) ---------------- # Reuse the shared helper so GEMM / MoE kernels share the exact same CShuffle skeleton. + sw_pf = None tw_pf = None bias_pf = None if const_expr(epilogue_pf is not None): - _, tw_pf, bias_pf = epilogue_pf + sw_pf, tw_pf, bias_pf = epilogue_pf mask24_i32 = arith.constant(0xFFFFFF) topk_i32_v = topk_i32 @@ -3814,7 +4282,7 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): _llvm_ptr_ty, arg_out ) out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) - out_base_idx = arith.index_cast(T.index, out_base_i64) + out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) def write_row_to_lds( *, @@ -3827,6 +4295,20 @@ def write_row_to_lds( num_acc_n: int, lds_out, ): + # Match origin/dev_a16w4: rely on sentinel padded rows + hardware OOB behavior. + fused2 = buffer_ops.buffer_load( + sorted_rsrc, row, vec_width=1, dtype=T.i32 + ) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + + t_ok = arith.cmpi(CmpIPredicate.ult, t2, tokens_i32) + s_ok = arith.cmpi(CmpIPredicate.ult, s2, topk_i32_v) + ts_ok = arith.andi(t_ok, s_ok) + t2_safe = arith.select(ts_ok, t2, arith.constant(0)) + s2_safe = arith.select(ts_ok, s2, arith.constant(0)) + t2_safe * topk_i32_v + s2_safe + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii if const_expr(tw_pf is not None): @@ -3995,6 +4477,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): persist_m, _sort_block_m, _cu_num if _persistent else 0, + xcd_swizzle, ) @flyc.jit @@ -4016,14 +4499,19 @@ def launch_mixed_moe_gemm2( stream: fx.Stream, ): _ = _cache_tag - allocator.finalized = False + allocator_pong.finalized = False + allocator_ping.finalized = False ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): - allocator.finalize() + allocator_pong.finalize() + allocator_ping.finalize() - n_in = arith.index_cast(T.index, i32_n_in) - gx = n_in / arith.constant(tile_n, index=True) - gy = arith.constant(0, index=True) + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + _tile_n_idx = arith.constant(tile_n, index=True) + _model_dim_pad_idx = arith.constant(model_dim_pad, index=True) + gx = ( + n_in - _model_dim_pad_idx + _tile_n_idx - arith.constant(1, index=True) + ) / _tile_n_idx if _persistent: gy = arith.constant(_cu_num, index=True) else: diff --git a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py index 0766ec00a8..ebf6245be3 100644 --- a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py +++ b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py @@ -1,28 +1,20 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -"""Fused silu_and_mul + MXFP4 quantization + sorted-scale write kernel (FlyDSL). +"""Fused silu_and_mul + quantization + sorted-scale write kernel (FlyDSL). Designed for split-K MOE stage1 post-processing: input : tmp_out (token_num * topk, inter_dim * 2) bf16 sorted : sorted_token_ids (sorted_len,) i32 -- packed (token<<0 | slot<<24) num_valid_ids (1,) i32 - output : out_fp4 raw byte buffer -- FP4x2 packed, row stride = inter_dim//2 - out_scale_sorted raw byte buffer -- tiled E8M0 scale (same layout as moe_mxfp4_sort) - -Grid: (num_sorted_rows, 1, 1) -- one workgroup per sorted row (including blockM padding). -Block: (BLOCK_THREADS, 1, 1) - -Each workgroup: - 1. Loads sorted_token_ids[bid] -> (token_id, slot_id) -> row = token_id * topk + slot_id - 2. If bid < num_valid_ids (valid row): - a. Reads gate = tmp_out[row, 0:inter_dim], up = tmp_out[row, inter_dim:2*inter_dim] - b. Computes silu(gate) * up in f32 - c. Per-1x32 MXFP4 quant -> writes packed FP4 + E8M0 scale in tiled layout - 3. If bid >= num_valid_ids (blockM padding row): - a. Writes zero FP4 bytes to out_fp4 - b. Writes zero E8M0 scale to out_scale_sorted (keeps tiled layout consistent) + output : out raw byte buffer (FP4x2, FP8, or BF16 depending on quant_mode) + out_scale_sorted raw byte buffer -- tiled E8M0 scale (quant_mode fp4/fp8 only) + +Compile options: + quant_mode : "fp4" | "fp8" | "none" + gui_layout : False -> gate-up separated [gate_0:N, up_0:N] + True -> block-interleaved [gate_0:16, up_0:16, gate_16:32, ...] """ import flydsl.compiler as flyc @@ -40,45 +32,65 @@ WARP_SIZE = 64 -def build_silu_and_mul_fq_module(inter_dim: int, topk: int): - """Return a JIT launcher for fused silu_and_mul + mxfp4 quant + scale sort. +def build_silu_and_mul_fq_module( + inter_dim: int, + topk: int, + quant_mode: str = "fp4", + gui_layout: bool = False, +): + """Return a JIT launcher for fused silu_and_mul + optional quant + scale sort. Parameters ---------- inter_dim : int Output columns of stage1 (after activation). Input has inter_dim*2 cols. - Must be divisible by 32 (MXFP4 block size). + Must be divisible by 32 (quant block size). topk : int Number of expert slots per token. + quant_mode : str + "fp4" -> MXFP4 output + e8m0 scale (tiled layout) + "fp8" -> MXFP8 (e4m3fn) output + e8m0 scale (tiled layout) + "none" -> bf16 output, no quantization (out_scale_sorted ignored) + gui_layout : bool + False -> input is gate-up separated [gate_0:N | up_0:N] + True -> input is block-interleaved [gate_0:16, up_0:16, gate_16:32, ...] """ assert inter_dim % 32 == 0, f"inter_dim={inter_dim} must be divisible by 32" + _need_fp4 = quant_mode == "fp4" + _need_fp8 = quant_mode == "fp8" + _need_quant = _need_fp4 or _need_fp8 + assert _need_fp4 or _need_fp8 or quant_mode == "none" scale_cols = inter_dim // 32 ELEMS_PER_THREAD = (inter_dim + BLOCK_THREADS - 1) // BLOCK_THREADS - # VEC: number of f32 elements each thread handles; must be even for FP4 packing VEC = max(ELEMS_PER_THREAD, 2) if VEC % 2 != 0: VEC += 1 assert 32 % VEC == 0, f"VEC={VEC} must divide 32 evenly" - # threads that actually participate in a 32-element quant group + if gui_layout: + assert VEC <= 16, f"VEC={VEC} must be <=16 for block-interleave layout" THREADS_PER_QUANT_BLK = 32 // VEC - # shuffle distances for intra-group reduction SHUFFLE_DISTS = [] d = 1 while d < THREADS_PER_QUANT_BLK: SHUFFLE_DISTS.append(d) d *= 2 + _fp_headroom = 2 if _need_fp4 else 8 + elem_bytes_bf16 = 2 + if _need_fp8: + from flydsl._mlir.dialects import rocdl + @flyc.kernel def silu_and_mul_fq_kernel( - x: fx.Tensor, # (token_num*topk, inter_dim*2) bf16 - out_fp4: fx.Tensor, # raw byte buffer for packed FP4 output - out_scale_sorted: fx.Tensor, # raw byte buffer for sorted E8M0 scales - sorted_ids: fx.Tensor, # (sorted_len,) i32 - num_valid_ids: fx.Tensor, # (1,) i32 - token_num: Int32, # host scalar + x: fx.Tensor, + out_buf: fx.Tensor, + out_scale_sorted: fx.Tensor, + sorted_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + token_num: Int32, ): bid = fx.block_idx.x tid = fx.thread_idx.x @@ -112,15 +124,15 @@ def silu_and_mul_fq_kernel( c0x80000000_i32 = arith.constant(0x80000000, type=i32) c0_f32 = arith.constant(0.0, type=f32) c1_f32 = arith.constant(1.0, type=f32) + c_headroom_i32 = arith.constant(_fp_headroom, type=i32) scale_cols_i32 = arith.constant(scale_cols, type=i32) inter_dim_i32 = arith.constant(inter_dim, type=i32) topk_i32 = arith.constant(topk, type=i32) n32_sort = scale_cols_i32 * c32_i32 - # Buffer resources in_rsrc = buffer_ops.create_buffer_resource(x, max_size=True) - out_rsrc = buffer_ops.create_buffer_resource(out_fp4, max_size=True) + out_rsrc = buffer_ops.create_buffer_resource(out_buf, max_size=True) scale_rsrc = buffer_ops.create_buffer_resource(out_scale_sorted, max_size=True) tid_rsrc = buffer_ops.create_buffer_resource(sorted_ids, max_size=True) nv_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) @@ -140,24 +152,24 @@ def silu_and_mul_fq_kernel( s_ok = arith.cmpi(CmpIPredicate.ult, slot_id, topk_i32) is_valid = arith.andi(row_in_range, arith.andi(t_ok, s_ok)) - def _f32_to_e2m1(qx_f32): - """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" - qx = qx_f32.bitcast(i32) - s = qx & c0x80000000_i32 - e = (qx >> c23_i32) & c0xFF_i32 - m = qx & c0x7FFFFF_i32 - adj_exp = arith.maxsi(c126_i32 - e, c0_i32) - m_denorm = (c0x400000_i32 | (m >> c1_i32)) >> adj_exp - is_denorm = arith.cmpi(CmpIPredicate.ult, e, c127_i32) - m = arith.select(is_denorm, m_denorm, m) - e = arith.maxsi(e - c126_i32, c0_i32) - combined = (e << c2_i32) | (m >> c21_i32) - rounded = (combined + c1_i32) >> c1_i32 - e2m1 = arith.minui(rounded, c7_i32) - return (s >> c28_i32) | e2m1 + if const_expr(_need_fp4): + + def _f32_to_e2m1(qx_f32): + qx = qx_f32.bitcast(i32) + s = qx & c0x80000000_i32 + e = (qx >> c23_i32) & c0xFF_i32 + m = qx & c0x7FFFFF_i32 + adj_exp = arith.maxsi(c126_i32 - e, c0_i32) + m_denorm = (c0x400000_i32 | (m >> c1_i32)) >> adj_exp + is_denorm = arith.cmpi(CmpIPredicate.ult, e, c127_i32) + m = arith.select(is_denorm, m_denorm, m) + e = arith.maxsi(e - c126_i32, c0_i32) + combined = (e << c2_i32) | (m >> c21_i32) + rounded = (combined + c1_i32) >> c1_i32 + e2m1 = arith.minui(rounded, c7_i32) + return (s >> c28_i32) | e2m1 thread_id = ArithValue(tid) - COLS_PER_ITER = BLOCK_THREADS * VEC for iter_idx in range_constexpr( @@ -174,42 +186,60 @@ def _f32_to_e2m1(qx_f32): _if_valid = scf.IfOp(is_valid, has_else=True) with ir.InsertionPoint(_if_valid.then_block): in_row = token_id * topk_i32 + slot_id - # FP4 output in token order: row = token_id * topk + slot_id - out_row_byte_base = in_row * arith.constant( - inter_dim // 2, type=i32 - ) - fp4_byte_off = out_row_byte_base + (col0 >> c1_i32) in_row_byte_base = in_row * arith.constant( inter_dim * 2 * elem_bytes_bf16, type=i32 ) - up_byte_offset = arith.constant( - inter_dim * elem_bytes_bf16, type=i32 - ) - gate_byte = in_row_byte_base + col0 * arith.constant( - elem_bytes_bf16, type=i32 - ) - up_byte = gate_byte + up_byte_offset - gate_dw = gate_byte >> c2_i32 - up_dw = up_byte >> c2_i32 vec_dw = VEC * elem_bytes_bf16 // 4 - gate_raw = buffer_ops.buffer_load( - in_rsrc, gate_dw, vec_width=vec_dw, dtype=i32 + if const_expr(gui_layout): + # Block-interleaved (block=16): + # [gate_0:16, up_0:16, gate_16:32, up_16:32, ...] + c16_i32 = arith.constant(16, type=i32) + block_idx = col0 >> c4_i32 + offset_in_blk = col0 & c15_i32 + gate_col = block_idx * c32_i32 + offset_in_blk + up_col = gate_col + c16_i32 + else: + # Gate-up separated: gate at col0, up at col0 + inter_dim + gate_col = col0 + up_col = col0 + inter_dim_i32 + + gate_byte = in_row_byte_base + gate_col * arith.constant( + elem_bytes_bf16, type=i32 ) - up_raw = buffer_ops.buffer_load( - in_rsrc, up_dw, vec_width=vec_dw, dtype=i32 + up_byte = in_row_byte_base + up_col * arith.constant( + elem_bytes_bf16, type=i32 ) + gate_dw = gate_byte >> c2_i32 + up_dw = up_byte >> c2_i32 vec_bf16_ty = T.vec(VEC, T.bf16) vec_f32_ty = T.vec(VEC, f32) + if const_expr(vec_dw == 1): vec1_i32_ty = T.vec(1, i32) - gate_vec = vector.from_elements(vec1_i32_ty, [gate_raw]) - up_vec = vector.from_elements(vec1_i32_ty, [up_raw]) - gate_bf16 = vector.bitcast(vec_bf16_ty, gate_vec) - up_bf16 = vector.bitcast(vec_bf16_ty, up_vec) + gate_raw = buffer_ops.buffer_load( + in_rsrc, gate_dw, vec_width=1, dtype=i32 + ) + up_raw = buffer_ops.buffer_load( + in_rsrc, up_dw, vec_width=1, dtype=i32 + ) + gate_bf16 = vector.bitcast( + vec_bf16_ty, + vector.from_elements(vec1_i32_ty, [gate_raw]), + ) + up_bf16 = vector.bitcast( + vec_bf16_ty, + vector.from_elements(vec1_i32_ty, [up_raw]), + ) else: + gate_raw = buffer_ops.buffer_load( + in_rsrc, gate_dw, vec_width=vec_dw, dtype=i32 + ) + up_raw = buffer_ops.buffer_load( + in_rsrc, up_dw, vec_width=vec_dw, dtype=i32 + ) gate_bf16 = vector.bitcast(vec_bf16_ty, gate_raw) up_bf16 = vector.bitcast(vec_bf16_ty, up_raw) gate_f32 = gate_bf16.extf(vec_f32_ty) @@ -234,116 +264,227 @@ def _f32_to_e2m1(qx_f32): ) act_vals.append(g * sig * u) - local_max = c0_f32 - for vi in range_constexpr(VEC): - abs_v = llvm.call_intrinsic( - f32, "llvm.fabs.f32", [act_vals[vi]], [], [] - ) - local_max = arith.maximumf(local_max, abs_v) - - for sh_dist in SHUFFLE_DISTS: - off = arith.constant(sh_dist, type=i32) - peer = local_max.shuffle_xor(off, c64_i32) - local_max = arith.maximumf(local_max, peer) - - max_i32_v = local_max.bitcast(i32) - max_rounded = (max_i32_v + c0x200000_i32) & c0xFF800000_i32 - exp_field = max_rounded >> c23_i32 - e8m0_biased = arith.maxsi(exp_field - c2_i32, c0_i32) - - quant_exp = c254_i32 - e8m0_biased - quant_scale = (quant_exp << c23_i32).bitcast(f32) - - fp4_vals = [] - for vi in range_constexpr(VEC): - scaled_v = act_vals[vi] * quant_scale - fp4_vals.append(_f32_to_e2m1(scaled_v)) - - packed_i32 = fp4_vals[0] | (fp4_vals[1] << c4_i32) - for k in range_constexpr(1, VEC // 2): - byte_k = fp4_vals[2 * k] | (fp4_vals[2 * k + 1] << c4_i32) - packed_i32 = packed_i32 | ( - byte_k << arith.constant(k * 8, type=i32) + if const_expr(_need_quant): + local_max = c0_f32 + for vi in range_constexpr(VEC): + abs_v = llvm.call_intrinsic( + f32, "llvm.fabs.f32", [act_vals[vi]], [], [] + ) + local_max = arith.maximumf(local_max, abs_v) + + for sh_dist in SHUFFLE_DISTS: + off = arith.constant(sh_dist, type=i32) + peer = local_max.shuffle_xor(off, c64_i32) + local_max = arith.maximumf(local_max, peer) + + max_i32_v = local_max.bitcast(i32) + max_rounded = (max_i32_v + c0x200000_i32) & c0xFF800000_i32 + exp_field = max_rounded >> c23_i32 + e8m0_biased = arith.maxsi(exp_field - c_headroom_i32, c0_i32) + quant_exp = c254_i32 - e8m0_biased + quant_scale = (quant_exp << c23_i32).bitcast(f32) + + if const_expr(_need_fp4): + out_row_byte_base = in_row * arith.constant( + inter_dim // 2, type=i32 + ) + out_byte_off = out_row_byte_base + (col0 >> c1_i32) + + fp4_vals = [] + for vi in range_constexpr(VEC): + scaled_v = act_vals[vi] * quant_scale + fp4_vals.append(_f32_to_e2m1(scaled_v)) + + packed_i32 = fp4_vals[0] | (fp4_vals[1] << c4_i32) + for k in range_constexpr(1, VEC // 2): + byte_k = fp4_vals[2 * k] | ( + fp4_vals[2 * k + 1] << c4_i32 + ) + packed_i32 = packed_i32 | ( + byte_k << arith.constant(k * 8, type=i32) + ) + + _pack_bytes = VEC // 2 + if const_expr(_pack_bytes == 1): + store_val = arith.TruncIOp(T.i8, packed_i32) + buffer_ops.buffer_store( + store_val, + out_rsrc, + out_byte_off, + offset_is_bytes=True, + ) + elif const_expr(_pack_bytes == 2): + store_val = arith.TruncIOp(T.i16, packed_i32) + buffer_ops.buffer_store( + store_val, + out_rsrc, + out_byte_off, + offset_is_bytes=True, + ) + else: + buffer_ops.buffer_store( + packed_i32, + out_rsrc, + out_byte_off, + offset_is_bytes=True, + ) + else: + out_row_byte_base = in_row * arith.constant( + inter_dim, type=i32 + ) + out_byte_off = out_row_byte_base + col0 + + scaled_vals = [] + for vi in range_constexpr(VEC): + scaled_vals.append(act_vals[vi] * quant_scale) + + if const_expr(VEC <= 4): + packed_i32 = c0_i32 + for _w in range_constexpr(VEC // 2): + packed_i32 = rocdl.cvt_pk_fp8_f32( + i32, + scaled_vals[2 * _w], + scaled_vals[2 * _w + 1], + packed_i32, + _w, + ) + if const_expr(VEC == 2): + store_val = arith.TruncIOp(T.i16, packed_i32) + buffer_ops.buffer_store( + store_val, + out_rsrc, + out_byte_off, + offset_is_bytes=True, + ) + else: + buffer_ops.buffer_store( + packed_i32, + out_rsrc, + out_byte_off, + offset_is_bytes=True, + ) + else: + for _wg in range_constexpr(VEC // 4): + _b = _wg * 4 + packed_w = c0_i32 + packed_w = rocdl.cvt_pk_fp8_f32( + i32, + scaled_vals[_b], + scaled_vals[_b + 1], + packed_w, + 0, + ) + packed_w = rocdl.cvt_pk_fp8_f32( + i32, + scaled_vals[_b + 2], + scaled_vals[_b + 3], + packed_w, + 1, + ) + word_off = out_byte_off + arith.constant( + _wg * 4, type=i32 + ) + buffer_ops.buffer_store( + packed_w, + out_rsrc, + word_off, + offset_is_bytes=True, + ) + + lane_in_blk = col0 & c31_i32 + _if_sw = scf.IfOp( + arith.cmpi(CmpIPredicate.eq, lane_in_blk, c0_i32) ) + with ir.InsertionPoint(_if_sw.then_block): + row_s = bid_i32 + col_s = col0 >> c5_i32 + d0 = row_s >> c5_i32 + d1 = (row_s >> c4_i32) & c1_i32 + d2 = row_s & c15_i32 + d3 = col_s >> c3_i32 + d4 = (col_s >> c2_i32) & c1_i32 + d5 = col_s & c3_i32 + s_byte_off = ( + d0 * n32_sort + + d3 * c256_i32 + + d5 * c64_i32 + + d2 * c4_i32 + + d4 * c2_i32 + + d1 + ) + e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) + buffer_ops.buffer_store( + e8m0_i8, + scale_rsrc, + s_byte_off, + offset_is_bytes=True, + ) + scf.YieldOp([]) - _pack_bytes = VEC // 2 - if const_expr(_pack_bytes == 1): - store_val = arith.TruncIOp(T.i8, packed_i32) - buffer_ops.buffer_store( - store_val, out_rsrc, fp4_byte_off, offset_is_bytes=True - ) - elif const_expr(_pack_bytes == 2): - store_val = arith.TruncIOp(T.i16, packed_i32) - buffer_ops.buffer_store( - store_val, out_rsrc, fp4_byte_off, offset_is_bytes=True - ) else: - buffer_ops.buffer_store( - packed_i32, out_rsrc, fp4_byte_off, offset_is_bytes=True + out_row_byte_base = in_row * arith.constant( + inter_dim * elem_bytes_bf16, type=i32 ) - - lane_in_blk = col0 & c31_i32 - _if_sw = scf.IfOp(arith.cmpi(CmpIPredicate.eq, lane_in_blk, c0_i32)) - with ir.InsertionPoint(_if_sw.then_block): - row_s = bid_i32 - col_s = col0 >> c5_i32 - d0 = row_s >> c5_i32 - d1 = (row_s >> c4_i32) & c1_i32 - d2 = row_s & c15_i32 - d3 = col_s >> c3_i32 - d4 = (col_s >> c2_i32) & c1_i32 - d5 = col_s & c3_i32 - s_byte_off = ( - d0 * n32_sort - + d3 * c256_i32 - + d5 * c64_i32 - + d2 * c4_i32 - + d4 * c2_i32 - + d1 + out_byte_off = out_row_byte_base + col0 * arith.constant( + elem_bytes_bf16, type=i32 ) - e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) - buffer_ops.buffer_store( - e8m0_i8, scale_rsrc, s_byte_off, offset_is_bytes=True + out_dw_off = out_byte_off >> c2_i32 + _vec_f32_ty = T.vec(VEC, f32) + _vec_bf16_ty = T.vec(VEC, T.bf16) + act_f32_vec = vector.from_elements(_vec_f32_ty, act_vals) + act_bf16_vec = act_f32_vec.truncf(_vec_bf16_ty) + act_i32 = vector.bitcast( + T.vec(VEC * elem_bytes_bf16 // 4, i32), act_bf16_vec ) - scf.YieldOp([]) + vec_dw_out = VEC * elem_bytes_bf16 // 4 + if const_expr(vec_dw_out == 1): + store_scalar = vector.extract( + act_i32, static_position=[0], dynamic_position=[] + ) + buffer_ops.buffer_store(store_scalar, out_rsrc, out_dw_off) + else: + buffer_ops.buffer_store(act_i32, out_rsrc, out_dw_off) + scf.YieldOp([]) with ir.InsertionPoint(_if_valid.else_block): - # Padding row: skip FP4 write (stage2 gather-loads by token_id, - # so padding rows are never read). Only write zero scale. - lane_in_blk_p = col0 & c31_i32 - _if_sw_p = scf.IfOp( - arith.cmpi(CmpIPredicate.eq, lane_in_blk_p, c0_i32) - ) - with ir.InsertionPoint(_if_sw_p.then_block): - row_s_p = bid_i32 - col_s_p = col0 >> c5_i32 - d0_p = row_s_p >> c5_i32 - d1_p = (row_s_p >> c4_i32) & c1_i32 - d2_p = row_s_p & c15_i32 - d3_p = col_s_p >> c3_i32 - d4_p = (col_s_p >> c2_i32) & c1_i32 - d5_p = col_s_p & c3_i32 - s_byte_off_p = ( - d0_p * n32_sort - + d3_p * c256_i32 - + d5_p * c64_i32 - + d2_p * c4_i32 - + d4_p * c2_i32 - + d1_p - ) - c0_i8 = arith.TruncIOp(T.i8, c0_i32) - buffer_ops.buffer_store( - c0_i8, scale_rsrc, s_byte_off_p, offset_is_bytes=True + if const_expr(_need_quant): + lane_in_blk_p = col0 & c31_i32 + _if_sw_p = scf.IfOp( + arith.cmpi(CmpIPredicate.eq, lane_in_blk_p, c0_i32) ) - scf.YieldOp([]) + with ir.InsertionPoint(_if_sw_p.then_block): + row_s_p = bid_i32 + col_s_p = col0 >> c5_i32 + d0_p = row_s_p >> c5_i32 + d1_p = (row_s_p >> c4_i32) & c1_i32 + d2_p = row_s_p & c15_i32 + d3_p = col_s_p >> c3_i32 + d4_p = (col_s_p >> c2_i32) & c1_i32 + d5_p = col_s_p & c3_i32 + s_byte_off_p = ( + d0_p * n32_sort + + d3_p * c256_i32 + + d5_p * c64_i32 + + d2_p * c4_i32 + + d4_p * c2_i32 + + d1_p + ) + c0_i8 = arith.TruncIOp(T.i8, c0_i32) + buffer_ops.buffer_store( + c0_i8, + scale_rsrc, + s_byte_off_p, + offset_is_bytes=True, + ) + scf.YieldOp([]) scf.YieldOp([]) scf.YieldOp([]) @flyc.jit def launch_silu_and_mul_fq( x: fx.Tensor, - out_fp4: fx.Tensor, + out_buf: fx.Tensor, out_scale_sorted: fx.Tensor, sorted_ids: fx.Tensor, num_valid_ids: fx.Tensor, @@ -357,7 +498,7 @@ def launch_silu_and_mul_fq( idx_rows = arith.index_cast(T.index, num_sorted_rows) launcher = silu_and_mul_fq_kernel( - x, out_fp4, out_scale_sorted, sorted_ids, num_valid_ids, token_num + x, out_buf, out_scale_sorted, sorted_ids, num_valid_ids, token_num ) launcher.launch( grid=(idx_rows, 1, 1), diff --git a/aiter/ops/flydsl/moe_kernels.py b/aiter/ops/flydsl/moe_kernels.py index 9bb63694b5..0d33bf687c 100644 --- a/aiter/ops/flydsl/moe_kernels.py +++ b/aiter/ops/flydsl/moe_kernels.py @@ -9,12 +9,11 @@ from typing import Dict, Optional from aiter.utility import dtypes -import flydsl.compiler as flyc import torch _KERNEL_PARAMS: Dict[str, Dict] = {} -_SUFFIX_RE = re.compile(r"(?P_fq)?(?:_sbm(?P\d+))?$") +_SUFFIX_RE = re.compile(r"(?P_fp4)?(?P_fp8)?(?:_sbm(?P\d+))?$") def flydsl_kernel_name( @@ -27,21 +26,18 @@ def flydsl_kernel_name( tile_k: int, mode: str = "", sort_block_m: int = 0, - fuse_fp4_quant: bool = False, ) -> str: - """Construct kernel name: ``flydsl_moe{stage}_a{a}_w{b}_{out}_t{M}x{N}x{K}[_{mode}][_fq][_sbm{S}]``.""" + """Construct kernel name: ``flydsl_moe{stage}_a{a}_w{b}_{out}_t{M}x{N}x{K}[_{mode}][_sbm{S}]``.""" name = f"flydsl_moe{stage}_a{a_dtype}_w{b_dtype}_{out_dtype}_t{tile_m}x{tile_n}x{tile_k}" if mode: name += f"_{mode}" - if fuse_fp4_quant: - name += "_fq" if sort_block_m > 0 and sort_block_m != tile_m: name += f"_sbm{sort_block_m}" return name def get_flydsl_kernel_params(name: str) -> Optional[Dict]: - """Lookup kernel params by name. Strips ``_fq`` / ``_sbm{N}`` suffixes transparently.""" + """Lookup kernel params by name. Strips ``_fp4`` / ``_fp8`` / ``_sbm{N}`` suffixes transparently.""" params = _KERNEL_PARAMS.get(name) if params is not None: return params @@ -51,8 +47,11 @@ def get_flydsl_kernel_params(name: str) -> Optional[Dict]: params = _KERNEL_PARAMS.get(base_name) if params is not None: extra: Dict = {} - if m.group("fq"): - extra["fuse_fp4_quant"] = True + if m.group("fp4"): + extra["out_dtype"] = "fp4" + if m.group("fp8"): + extra["out_dtype"] = "fp8" + extra["a_scale_one"] = True if m.group("sbm") is not None: extra["sort_block_m"] = int(m.group("sbm")) return {**params, **extra} @@ -64,52 +63,71 @@ def get_flydsl_stage1_kernels( ) -> Dict[str, Dict]: """Return {kernelName: params} for all supported stage1 configs.""" kernels = {} - is_fp4 = b_dtype == "fp4" + is_fp4_a = a_dtype == "fp4" + is_fp4_b = b_dtype == "fp4" - tile_ns = [32, 64, 128] if is_fp4 else [128] + tile_ns = [32, 64, 128] if is_fp4_b else [128] tile_ks = [256] - tile_ms = [16, 32, 64, 128] + tile_ms = [32, 64, 128] + waves_per_eus = [1, 2, 3, 4] k_batches = [1, 2, 4, 7, 14] b_nts = [0, 2] + xcd_swizzles = [0, 4] for tm in tile_ms: - if tm in [16, 32]: + if tm == 32: tile_ns = [32, 64, 128] else: - tile_ns = [64, 128] + tile_ns = [64, 128] if is_fp4_a else [128, 256] for tn in tile_ns: for tk in tile_ks: for wpe in waves_per_eus: - for kb in k_batches if wpe == 3 else [1]: - gate_onlys = [False, True] if kb > 1 else [False] + for kb in k_batches if wpe == 3 and tm == 32 and is_fp4_a else [1]: for bnt in b_nts: + gate_onlys = ( + [False, True] if kb > 1 and is_fp4_a else [False] + ) for go in gate_onlys: - name = flydsl_kernel_name( - 1, a_dtype, b_dtype, out_dtype, tm, tn, tk - ) - if wpe != 1: - name += f"_w{wpe}" - if kb != 1: - name += f"_kb{kb}" - if bnt != 2: - name += f"_bnt{bnt}" - if go: - name += "_go" - kernels[name] = { - "stage": 1, - "a_dtype": a_dtype, - "b_dtype": b_dtype, - "out_dtype": out_dtype, - "tile_m": tm, - "tile_n": tn, - "tile_k": tk, - "MPerBlock": tm, - "waves_per_eu": wpe, - "k_batch": kb, - "b_nt": bnt, - "gate_only": go, - } + for xcd in xcd_swizzles: + name = flydsl_kernel_name( + 1, a_dtype, b_dtype, out_dtype, tm, tn, tk + ) + if wpe != 1: + name += f"_w{wpe}" + if kb != 1: + name += f"_kb{kb}" + if bnt != 2: + name += f"_bnt{bnt}" + if go: + name += "_go" + if a_dtype == "fp8": + name += "_gui" + if xcd > 0: + name += f"_xcd{xcd}" + kernels[name] = { + "stage": 1, + "a_dtype": a_dtype, + "b_dtype": b_dtype, + "out_dtype": out_dtype, + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "MPerBlock": tm, + "waves_per_eu": wpe, + "k_batch": kb, + "b_nt": bnt, + "gate_mode": ( + "mock_gate_only" + if go + else ( + "interleave" + if a_dtype == "fp8" + else "separated" + ) + ), + "xcd_swizzle": xcd, + } return kernels @@ -124,30 +142,41 @@ def get_flydsl_stage2_kernels( tile_ms = [16, 32, 64, 128] if is_fp4 else [32, 64, 128] modes = ["atomic", "reduce"] + b_nts = [0, 2] + + xcd_swizzles = [0, 4] + for tm in tile_ms: for tn in tile_ns: for tk in tile_ks: for mode in modes: - base_name = flydsl_kernel_name( - 2, a_dtype, b_dtype, out_dtype, tm, tn, tk, mode - ) - base_params = { - "stage": 2, - "a_dtype": a_dtype, - "b_dtype": b_dtype, - "out_dtype": out_dtype, - "tile_m": tm, - "tile_n": tn, - "tile_k": tk, - "mode": mode, - "MPerBlock": tm, - } - kernels[base_name] = base_params - # Persistent variant: round-robin over M tiles, grid_y=cu_num. - kernels[base_name + "_persist"] = { - **base_params, - "persist": True, - } + for bnt in b_nts: + for xcd in xcd_swizzles: + base_name = flydsl_kernel_name( + 2, a_dtype, b_dtype, out_dtype, tm, tn, tk, mode + ) + if bnt != 0: + base_name += f"_bnt{bnt}" + if xcd > 0: + base_name += f"_xcd{xcd}" + base_params = { + "stage": 2, + "a_dtype": a_dtype, + "b_dtype": b_dtype, + "out_dtype": out_dtype, + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "mode": mode, + "MPerBlock": tm, + "b_nt": bnt, + "xcd_swizzle": xcd, + } + kernels[base_name] = base_params + kernels[base_name + "_persist"] = { + **base_params, + "persist": True, + } return kernels @@ -177,17 +206,20 @@ def compile_flydsl_moe_stage1( out_dtype: str, act: str = "silu", persist_m: int = 1, - fuse_fp4_quant: bool = False, - fuse_sort_scale: bool = False, use_async_copy: bool = False, k_batch: int = 1, waves_per_eu: int = 3, b_nt: int = 2, - gate_only: bool = False, + gate_mode: str = "separated", + model_dim_pad: int = 0, + inter_dim_pad: int = 0, + enable_bias: bool = False, + a_scale_one: bool = False, + xcd_swizzle: int = 0, ): """Compile stage1 kernel (cached via underlying lru_cache).""" if b_dtype == "fp4": - from .kernels.mixed_moe_gemm_2stage import compile_mixed_moe_gemm1 + from .kernels.mixed_moe_gemm_2stage import compile_mixed_moe_gemm1, GateMode return compile_mixed_moe_gemm1( model_dim=model_dim, @@ -203,13 +235,16 @@ def compile_flydsl_moe_stage1( out_dtype=out_dtype, act=act, persist_m=persist_m, - fuse_fp4_quant=fuse_fp4_quant, - fuse_sort_scale=fuse_sort_scale, use_async_copy=use_async_copy, k_batch=k_batch, waves_per_eu=waves_per_eu, b_nt=b_nt, - gate_only=gate_only, + gate_mode=GateMode(gate_mode), + model_dim_pad=model_dim_pad, + inter_dim_pad=inter_dim_pad, + enable_bias=enable_bias, + a_scale_one=a_scale_one, + xcd_swizzle=xcd_swizzle, ) else: from .kernels.moe_gemm_2stage import compile_moe_gemm1 @@ -243,6 +278,11 @@ def compile_flydsl_moe_stage2( accumulate: bool = True, persist_m: int = 1, sort_block_m: int = 0, + b_nt: int = 0, + model_dim_pad: int = 0, + inter_dim_pad: int = 0, + xcd_swizzle: int = 0, + enable_bias: bool = False, ): """Compile stage2 kernel (cached via underlying lru_cache).""" if b_dtype == "fp4": @@ -263,6 +303,11 @@ def compile_flydsl_moe_stage2( accumulate=accumulate, persist_m=persist_m, sort_block_m=sort_block_m, + b_nt=b_nt, + model_dim_pad=model_dim_pad, + inter_dim_pad=inter_dim_pad, + xcd_swizzle=xcd_swizzle, + enable_bias=enable_bias, ) else: from .kernels.moe_gemm_2stage import compile_moe_gemm2 @@ -313,8 +358,10 @@ def _s1_args_fp4( k_in, size_expert_ids_in, dev, + bias=None, ): empty_f32 = torch.empty(0, device=dev, dtype=torch.float32) + _bias = bias if bias is not None else empty_f32 return ( _view_safe(out), _view_safe(a), @@ -325,7 +372,7 @@ def _s1_args_fp4( sorted_expert_ids, sorted_weights, num_valid_ids, - empty_f32, + _bias, out_scale_sorted, token_num, n_in, @@ -383,8 +430,13 @@ def _s2_args_fp4( k_in, blocks, dev, + bias=None, ): - empty_f32 = torch.empty(0, device=dev, dtype=torch.float32) + _bias = ( + bias.view(-1) + if bias is not None + else torch.empty(0, device=dev, dtype=torch.float32) + ) return ( _view_safe(target), _view_safe(a), @@ -395,7 +447,7 @@ def _s2_args_fp4( sorted_expert_ids, sorted_weights, num_valid_ids, - empty_f32, + _bias, token_num, n_in, k_in, @@ -438,23 +490,37 @@ def _s2_args_std( def _run_compiled(exe, args): - """First call: ``flyc.compile(exe, *args)`` compiles **and** executes the kernel. - Subsequent calls: fast dispatch via the cached ``CompiledFunction``. + """Call the JitFunction with the given args. + JitFunction.__call__ handles compilation caching internally. """ - cf = getattr(exe, "_aiter_cf", None) - if cf is None: - cf = flyc.compile(exe, *args) - exe._aiter_cf = cf - else: - cf(*args) + try: + exe(*args) + except Exception: + # JitFunction.__call__ leaks ir.Context on compilation failure, + # causing all subsequent JitFunction calls to take a wrong code path + # (self.func(*args) without CompilationContext → gpu_module_body error). + # Clean up leaked contexts to isolate failures. + try: + from flydsl._mlir import ir + + while ir.Context.current is not None: + ir.Context.current.__exit__(None, None, None) + except Exception: + pass + raise @functools.cache -def _get_compiled_silu_fq(inter_dim: int, topk: int): - """Compile and cache the fused silu_and_mul + mxfp4 quant + scale-sort kernel.""" +def _get_compiled_silu_fused( + inter_dim: int, + topk: int, + quant_mode: str = "fp4", + gui_layout: bool = False, +): + """Compile and cache the fused silu_and_mul + quant + scale-sort kernel.""" from aiter.ops.flydsl.kernels.silu_and_mul_fq import build_silu_and_mul_fq_module - return build_silu_and_mul_fq_module(inter_dim, topk) + return build_silu_and_mul_fq_module(inter_dim, topk, quant_mode, gui_layout) # Public API @@ -480,33 +546,36 @@ def flydsl_moe_stage1( a1_scale: Optional[torch.Tensor] = None, sorted_weights: Optional[torch.Tensor] = None, persist_m: int = 0, - fuse_fp4_quant: bool = False, - fuse_sort_scale: bool = False, use_async_copy: bool = False, k_batch: int = 1, waves_per_eu: int = 3, - b_nt: int = 2, - gate_only: bool = False, + b_nt: int = 0, + gate_mode: str = "separated", + model_dim_pad: int = 0, + inter_dim_pad: int = 0, + bias: Optional[torch.Tensor] = None, + a_scale_one: bool = False, + xcd_swizzle: int = 0, ): """Fused gate+up GEMM (MOE stage1). a: (token_num, model_dim), w1: (E, 2*inter_dim, model_dim) pre-shuffled. + model_dim and inter_dim INCLUDE padding (model_dim_pad, inter_dim_pad). + bias: optional (E, 2*inter_dim) f32 bias added before activation. For fp4 stage1, `w1`/`w1_scale` must use the same preshuffle layout as - `shuffle_weight(..., (16, 16))` and `e8m0_shuffle(...)`. + `shuffle_weight_a16w4(w1, 16, True)` and `shuffle_scale_a16w4(w1_scale, E, True)`. - When fuse_sort_scale=True, the kernel writes e8m0 scales in sorted tiled - layout directly, avoiding a separate moe_mxfp4_sort call. + When fuse_quant=True, the kernel fuses quantization (fp4/fp8, inferred from + out_dtype) and writes e8m0 scales in sorted tiled layout directly. When k_batch>1 (split-K), the kernel outputs gate/up partials via atomic add into a zeroed buffer, then silu_and_mul fuses activation + reduction. - When gate_only=True (requires k_batch>1), each workgroup computes only - one B-tile stream (no gate/up interleaving). The grid X doubles so - that by_n naturally covers both gate and up regions. + gate_mode controls the gate/up computation strategy (see GateMode enum). Returns: Basic: out - fuse_sort_scale: (out, out_scale_sorted) + fuse_quant: (out, out_scale_sorted) """ token_num = a.shape[0] E = w1.shape[0] @@ -516,20 +585,33 @@ def flydsl_moe_stage1( if a_dtype == "fp4": model_dim = model_dim * 2 - torch_out_dtype = ( - dtypes.fp4x2 - if fuse_fp4_quant - else dtypes.bf16 if out_dtype == "bf16" else dtypes.fp16 - ) + _need_fp4 = out_dtype == "fp4" + _need_fp8 = out_dtype == "fp8" + _fuse_any_quant = _need_fp4 or _need_fp8 + _base_out_dtype = "bf16" if _fuse_any_quant else out_dtype + + if _need_fp4: + torch_out_dtype = dtypes.fp4x2 + elif _need_fp8: + torch_out_dtype = dtypes.fp8 + else: + torch_out_dtype = dtypes.bf16 if out_dtype == "bf16" else dtypes.fp16 _is_splitk = k_batch > 1 + gate_up_interleave = gate_mode == "interleave" dev = a.device - _splitk_fq = _is_splitk and fuse_fp4_quant + _splitk_fp4 = _is_splitk and _need_fp4 + _gui_sk = gate_up_interleave and _is_splitk + _gui_sk_fused = _gui_sk and _fuse_any_quant if out is None: - if fuse_fp4_quant: + if _need_fp4 or (_gui_sk_fused and _need_fp4): + out = torch.empty( + (token_num, topk, inter_dim // 2), dtype=dtypes.fp4x2, device=dev + ) + elif _need_fp8 or (_gui_sk_fused and _need_fp8): out = torch.empty( - (token_num, topk, inter_dim // 2), dtype=torch_out_dtype, device=dev + (token_num, topk, inter_dim), dtype=dtypes.fp8, device=dev ) else: out = torch.empty( @@ -537,7 +619,7 @@ def flydsl_moe_stage1( ) if _is_splitk: - torch_tmp_out_dtype = dtypes.bf16 if out_dtype == "bf16" else dtypes.fp16 + torch_tmp_out_dtype = dtypes.bf16 if _base_out_dtype == "bf16" else dtypes.fp16 tmp_out = torch.zeros( (token_num, topk, inter_dim * 2), dtype=torch_tmp_out_dtype, device=dev ) @@ -556,10 +638,10 @@ def flydsl_moe_stage1( else torch.empty(0, device=dev, dtype=torch.float32) ) - _need_quant = fuse_fp4_quant or _splitk_fq - _need_sort = _need_quant and (fuse_sort_scale or _splitk_fq) + _need_quant = _fuse_any_quant or _splitk_fp4 or _gui_sk_fused + _need_sort = _need_quant - _sort_block_m = max(32, tile_m) + _sort_block_m = tile_m _all_blks = sorted_expert_ids.shape[0] _dense_blks = ( min(token_num * topk * _sort_block_m, sorted_token_ids.shape[0]) @@ -584,8 +666,7 @@ def flydsl_moe_stage1( # split-K GEMM kernel does not fuse quant; the fused silu_and_mul_fq kernel # handles activation + quant + scale-sort after the GEMM completes. - _gemm_fq = fuse_fp4_quant and not _is_splitk - _gemm_fss = fuse_sort_scale and not _is_splitk + _gemm_out_dtype = _base_out_dtype if _is_splitk else out_dtype _kernel_out = tmp_out if _is_splitk else out is_fp4 = b_dtype == "fp4" @@ -609,6 +690,7 @@ def flydsl_moe_stage1( _k_in, _grid_y, dev, + bias=bias.view(-1) if bias is not None else torch.empty(0, device=dev), ) else: args = _s1_args_std( @@ -638,24 +720,62 @@ def flydsl_moe_stage1( doweight_stage1=(sorted_weights is not None), a_dtype=a_dtype, b_dtype=b_dtype, - out_dtype=out_dtype, + out_dtype=_gemm_out_dtype, act=act, persist_m=_persist_m, - fuse_fp4_quant=_gemm_fq, - fuse_sort_scale=_gemm_fss, use_async_copy=use_async_copy, k_batch=k_batch, waves_per_eu=waves_per_eu, b_nt=b_nt, - gate_only=gate_only, + gate_mode=gate_mode, + model_dim_pad=model_dim_pad, + inter_dim_pad=inter_dim_pad, + enable_bias=(bias is not None), + a_scale_one=a_scale_one, + xcd_swizzle=xcd_swizzle, ) _run_compiled(exe, args) - if _splitk_fq: - _silu_fq = _get_compiled_silu_fq(inter_dim, topk) - num_sorted_rows = sorted_token_ids.shape[0] + num_sorted_rows = sorted_token_ids.shape[0] + if _gui_sk_fused: + _quant_mode = "fp4" if _need_fp4 else "fp8" + _silu_fused_k = _get_compiled_silu_fused( + inter_dim, topk, _quant_mode, gui_layout=True + ) + _run_compiled( + _silu_fused_k, + ( + tmp_out.view(-1, inter_dim * 2), + out.view(-1).view(torch.uint8), + out_scale_sorted_flat, + sorted_token_ids, + num_valid_ids, + token_num, + num_sorted_rows, + torch.cuda.current_stream(), + ), + ) + elif _gui_sk: + _silu_fused_k = _get_compiled_silu_fused( + inter_dim, topk, "none", gui_layout=True + ) + _run_compiled( + _silu_fused_k, + ( + tmp_out.view(-1, inter_dim * 2), + out.view(-1).view(torch.uint8), + out_scale_sorted_flat, + sorted_token_ids, + num_valid_ids, + token_num, + num_sorted_rows, + torch.cuda.current_stream(), + ), + ) + elif _splitk_fp4: + _silu_fused_k = _get_compiled_silu_fused(inter_dim, topk) _run_compiled( - _silu_fq, + _silu_fused_k, ( tmp_out.view(-1, inter_dim * 2), out.view(-1).view(torch.uint8), @@ -672,7 +792,7 @@ def flydsl_moe_stage1( silu_and_mul(out.view(-1, inter_dim), tmp_out.view(-1, inter_dim * 2)) - if fuse_fp4_quant: + if _fuse_any_quant and _need_sort: from aiter.utility.dtypes import fp8_e8m0 out_scale_sorted = out_scale_sorted_flat.view(fp8_e8m0).view( @@ -704,11 +824,17 @@ def flydsl_moe_stage2( sorted_weights: Optional[torch.Tensor] = None, sort_block_m: int = 0, persist: Optional[bool] = None, + b_nt: int = 0, + model_dim_pad: int = 0, + inter_dim_pad: int = 0, + xcd_swizzle: int = 0, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Down-projection GEMM (MOE stage2). Supports atomic/reduce modes. a: (token_num, topk, inter_dim), w1: (E, model_dim, inter_dim) pre-shuffled. Returns (token_num, model_dim). + bias: optional (E, model_dim) f32 bias added after GEMM. sort_block_m: block_size used by moe_sorting / stage1. When 0 (default), assumed equal to tile_m. When set, stage2 can use a different tile_m @@ -760,6 +886,9 @@ def flydsl_moe_stage2( else: _persist_m = -1 if m_blocks > 256 else 1 + if a_dtype == "fp8": + _persist_m = 1 + is_fp4 = b_dtype == "fp4" _n_in = model_dim _k_in = inter_dim @@ -788,6 +917,7 @@ def flydsl_moe_stage2( _k_in, m_blocks, dev, + bias=bias, ) else: args = _s2_args_std( @@ -821,6 +951,11 @@ def flydsl_moe_stage2( accumulate=accumulate, persist_m=_persist_m, sort_block_m=sort_block_m, + b_nt=b_nt, + model_dim_pad=model_dim_pad, + inter_dim_pad=inter_dim_pad, + xcd_swizzle=xcd_swizzle, + enable_bias=(bias is not None), ) _run_compiled(exe, args) diff --git a/aiter/utility/mp_tuner.py b/aiter/utility/mp_tuner.py index 46e6ab065b..04c7330600 100644 --- a/aiter/utility/mp_tuner.py +++ b/aiter/utility/mp_tuner.py @@ -28,6 +28,7 @@ def worker( atol=1e-2, printLog=False, tol_err_ratio=0.05, + compare_fn=None, ): from aiter.test_common import run_perftest @@ -77,18 +78,26 @@ def worker( if isinstance(ref[i], torch.Tensor): if res[i].shape != ref[i].shape: res[i] = res[i].view(-1)[: ref[i].numel()].view(ref[i].shape) - if ref[i].dtype.itemsize == 1: - ref[i] = ref[i].view(torch.uint8).to(dtypes.fp32) - res[i] = res[i].view(torch.uint8).to(dtypes.fp32) - err_ratio = checkAllclose( - ref[i], - res[i], - atol=atol, - rtol=rtol, - tol_err_ratio=tol_err_ratio, - printLog=printLog, - msg=f"info:{info} res[{i}] ", - ) + if compare_fn is not None: + err_ratio = compare_fn( + ref[i], + res[i], + msg=f"info:{info} res[{i}] ", + printLog=printLog, + ) + else: + if ref[i].dtype.itemsize == 1: + ref[i] = ref[i].view(torch.uint8).to(dtypes.fp32) + res[i] = res[i].view(torch.uint8).to(dtypes.fp32) + err_ratio = checkAllclose( + ref[i], + res[i], + atol=atol, + rtol=rtol, + tol_err_ratio=tol_err_ratio, + printLog=printLog, + msg=f"info:{info} res[{i}] ", + ) max_err_ratio = max(max_err_ratio, err_ratio) except RuntimeError as e: if "CUDA" in str(e) or "HIP" in str(e) or "out of memory" in str(e).lower(): @@ -218,9 +227,11 @@ def work_group(GPUIDMap, fast_mode, err_ratio, in_data, tasks, verbose=False): torch.cuda.synchronize() _prev_ref_key = _cur_key - # Extract rtol, atol from rest if available, otherwise use defaults + # Extract rtol, atol from rest if available, otherwise use defaults. + # Optional rest[2]: custom compare callable (e.g. cosine diff for a8w4). rtol = rest[0] if len(rest) > 0 else 1e-2 atol = rest[1] if len(rest) > 1 else 1e-2 + compare_fn = rest[2] if len(rest) > 2 and callable(rest[2]) else None work_args = ( gpu_id, @@ -233,6 +244,7 @@ def work_group(GPUIDMap, fast_mode, err_ratio, in_data, tasks, verbose=False): atol, verbose, # Use the verbose from work_group parameter err_ratio, # Use the err_ratio from work_group parameter + compare_fn, ) # Run worker with explicit GPU ID @@ -419,7 +431,7 @@ def add_dummy_result(k, results_list): elapsed = time.time() - task_start_times[k] if verbose: print( - f"[Done] Task {k}/{len(rets)-1} completed in {elapsed:.1f}s ({len(result_dict)}/{len(rets)} done)" + f"[Done] Task {k}/{len(rets) - 1} completed in {elapsed:.1f}s ({len(result_dict)}/{len(rets)} done)" ) except MPTimeoutError: @@ -575,14 +587,14 @@ def add_dummy_result(k, results_list): timeout_count = sum(1 for _, reason in failed_tasks if reason == "timeout") crash_count = len(failed_tasks) - timeout_count summary = ( - f"\n{'='*60}\n" + f"\n{'=' * 60}\n" f"Tuning Summary:\n" f" Total tasks: {len(rets)}\n" f" Successful: {len(rets) - len(failed_tasks)}\n" f" Failed: {len(failed_tasks)}\n" f" - Timeouts (GPU hang): {timeout_count}\n" f" - Crashes (memory fault): {crash_count}\n" - f"{'='*60}" + f"{'=' * 60}" ) logger.warning(summary) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py index ac0540e768..292b764528 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py @@ -20,6 +20,8 @@ torch_moe_stage1, torch_moe_stage2, torch_moe, + cktile_moe_stage1, + cktile_moe_stage2, ) from aiter import ck_moe_stage1_fwd, ck_moe_stage2_fwd, dtype2str_dict from aiter.ops.shuffle import ( @@ -60,10 +62,56 @@ FLYDSL_FALLBACK_TAG = "flydsl_fallback" +TUNE_MOE_EXPERT_BALANCE = ( + os.environ.get("TUNE_MOE_EXPERT_BALANCE", "False").lower() == "true" +) +COS_DIFF_THRESHOLD = 1e-1 + + +def torch_dynamic_mxfp8_quant(x: torch.Tensor): + """MXFP8 quantization (e4m3fn + e8m0 block scale, block=32). + + Same numerics as ``aiter/bench_stage2_a8w4.py`` for a8w4 activations. + """ + BLOCK = 32 + orig_shape = x.shape + x_f32 = x.reshape(-1, x.shape[-1] // BLOCK, BLOCK).float() + + amax, _ = torch.max(torch.abs(x_f32), dim=-1) + amax_i32 = amax.view(torch.int32) + amax_rounded = (amax_i32 + 0x200000) & 0xFF800000 + exp_field = (amax_rounded >> 23) & 0xFF + + e8m0_biased = torch.clamp(exp_field - 8, min=0) + quant_exp = 254 - e8m0_biased + quant_scale = (quant_exp << 23).view(torch.float32) + + scaled = x_f32 * quant_scale.unsqueeze(-1) + fp8_vals = scaled.to(torch.float8_e4m3fn) + fp8_bytes = fp8_vals.view(torch.uint8) + + e8m0_bytes = e8m0_biased.to(torch.uint8).view(dtypes.fp8_e8m0) + return fp8_bytes.view(*orig_shape), e8m0_bytes.view( + *orig_shape[:-1], orig_shape[-1] // BLOCK + ) + + +def cosine_diff_compare(ref, res, msg="", printLog=True): + from aiter import logger + + x = ref.double().flatten() + y = res.double().flatten() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + if printLog: + if cos_diff < COS_DIFF_THRESHOLD: + logger.info(f"{msg}[cosine_diff={cos_diff:.6f} \033[32mpassed~\033[0m]") + else: + logger.info(f"{msg}[cosine_diff={cos_diff:.6f} \033[31mfailed!\033[0m]") + return cos_diff if cos_diff >= COS_DIFF_THRESHOLD else 0.0 -class FmoeTuner(TunerCommon): +class FmoeTuner(TunerCommon): ARG_DEFAULTS = { **TunerCommon.ARG_DEFAULTS, "verbose": False, @@ -272,6 +320,87 @@ def ck_moe_stage2_fwd_out( act_type, ) + @staticmethod + def cktile_moe_stage1_out( + a1_fp8, + w1_qt_shffle_ck, + w2_qt_shffle_ck, + sorted_ids, + sorted_expert_ids, + sorted_weights, + num_valid_ids, + w1_scale_aiter, + bias, + dtype, + topk, + blockM, + act_type, + ): + M_sorted = sorted_ids.shape[0] + model_dim = a1_fp8.shape[1] + a1_scale = torch.ones( + (M_sorted, model_dim // 32), dtype=dtypes.fp8_e8m0, device=a1_fp8.device + ) + return cktile_moe_stage1( + a1_fp8, + w1_qt_shffle_ck, + w2_qt_shffle_ck, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + None, + topk, + blockM, + a1_scale=a1_scale, + w1_scale=w1_scale_aiter.view(dtypes.fp8_e8m0), + sorted_weights=sorted_weights, + bias1=bias, + activation=act_type, + split_k=1, + dtype=dtype, + ) + + @staticmethod + def cktile_moe_stage2_out( + a2_qt, + w1_qt_shffle_ck, + w2_qt_shffle_ck, + sorted_ids, + sorted_expert_ids, + sorted_weights, + num_valid_ids, + w2_scale_aiter, + a2_scale_sort, + bias, + dtype, + topk, + blockM, + act_type, + ): + token_num = a2_qt.shape[0] + model_dim = w2_qt_shffle_ck.shape[1] + out = torch.zeros( + (token_num, model_dim), + dtype=dtype, + device=a2_qt.device, + ) + return cktile_moe_stage2( + a2_qt, + w1_qt_shffle_ck, + w2_qt_shffle_ck, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + w2_scale=w2_scale_aiter.view(dtypes.fp8_e8m0), + a2_scale=a2_scale_sort, + block_m=blockM, + activation=act_type, + sorted_weights=sorted_weights, + bias2=bias, + ) + @staticmethod def run_flydsl_stage1_out( a1_qt, @@ -282,19 +411,22 @@ def run_flydsl_stage1_out( num_valid_ids, w1_scale_aiter, a1_scale, + bias, dtype, topk, kparams, blockM, + q_dtype_a, q_type, act_type, ): act = "swiglu" if act_type == ActivationType.Swiglu else "silu" - fuse_fq = kparams.get("fuse_fp4_quant", False) + a_scale_one = kparams.get("a_scale_one", False) + _out_dtype = kparams["out_dtype"] token_num = a1_qt.shape[0] inter_dim = w1_qt_shffle_ck.shape[1] // 2 result = flydsl_moe_stage1( - a=a1_qt, + a=a1_qt.to(dtypes.fp8) if q_dtype_a == dtypes.fp8 else a1_qt, w1=w1_qt_shffle_ck, sorted_token_ids=sorted_ids, sorted_expert_ids=sorted_expert_ids, @@ -305,7 +437,7 @@ def run_flydsl_stage1_out( tile_k=kparams["tile_k"], a_dtype=kparams["a_dtype"], b_dtype=kparams["b_dtype"], - out_dtype=kparams["out_dtype"], + out_dtype=_out_dtype, act=act, w1_scale=w1_scale_aiter, a1_scale=a1_scale, @@ -314,15 +446,20 @@ def run_flydsl_stage1_out( k_batch=kparams.get("k_batch", 1), waves_per_eu=kparams.get("waves_per_eu", 3), b_nt=kparams.get("b_nt", 2), - gate_only=kparams.get("gate_only", False), - fuse_fp4_quant=fuse_fq, - fuse_sort_scale=fuse_fq, + gate_mode=kparams.get("gate_mode", "separated"), + a_scale_one=a_scale_one, + xcd_swizzle=kparams.get("xcd_swizzle", 0), + bias=bias, ) if isinstance(result, tuple): out_raw = result[0] - total_fp4_bytes = token_num * topk * (inter_dim // 2) - fp4_flat = out_raw.view(-1).view(torch.uint8)[:total_fp4_bytes] - return fp4_flat.view(dtypes.fp4x2).reshape(token_num, topk, -1) + if _out_dtype == "fp4": + total_fp4_bytes = token_num * topk * (inter_dim // 2) + fp4_flat = out_raw.view(-1).view(torch.uint8)[:total_fp4_bytes] + return fp4_flat.view(dtypes.fp4x2).reshape(token_num, topk, -1) + else: + # fuse_fp8: out_raw is fp8 tensor, shape (token_num, topk, inter_dim) + return out_raw.reshape(token_num, topk, -1) return result @staticmethod @@ -336,6 +473,7 @@ def run_flydsl_stage2_out( w2_scale_shuffled_flydsl, a2_scale, moe_buf, + bias, dtype, topk, kparams, @@ -368,6 +506,9 @@ def run_flydsl_stage2_out( sorted_weights=sorted_weights, sort_block_m=sort_block_m, persist=persist, + b_nt=kparams.get("b_nt", 0), + xcd_swizzle=kparams.get("xcd_swizzle", 0), + bias=bias, ) @staticmethod @@ -600,7 +741,16 @@ def generate_data( else: w1_qt = w1_qt.view(w1.shape[0], w1.shape[1], w1.shape[2] // 2) w2_qt = w2_qt.view(w2.shape[0], w2.shape[1], w2.shape[2] // 2) - score = torch.randn((token, expert), dtype=dtype) + if TUNE_MOE_EXPERT_BALANCE: + score = torch.zeros((token, expert), dtype=dtype) + start_col = 0 + end_col = topk + for token_id in range(token): + score[token_id, start_col:end_col] = 1.0 + start_col = end_col % expert + end_col = start_col + topk + else: + score = torch.randn((token, expert), dtype=dtype) topk_weights, topk_ids = fused_topk(input, score, topk, True) if q_type == QuantType.per_1x128: a1_qt, a1_scale = aiter.pertoken_quant( @@ -610,9 +760,9 @@ def generate_data( a1_scale = a1_scale.squeeze(-1) elif ( q_type == aiter.QuantType.per_1x32 - and (q_dtype_a in [dtypes.bf16, dtypes.fp16]) + and (q_dtype_a in [dtypes.bf16, dtypes.fp16, dtypes.fp8]) and q_dtype_w == dtypes.fp4x2 - ): # a16w4 + ): # a16w4 or a8w4 a1_qt = input.to(dtype) a1_scale = None else: @@ -796,6 +946,9 @@ def generate_data_2stages( blockM, device, ) + # Pre-bind so branches that skip shuffle_scale_* still reach `is None` below. + w1_scale_aiter = None + w2_scale_aiter = None if q_dtype_w == torch.int4: w1_qt_shffle_ck = rearrange_4bit_elements( convert_int8_to_uint32_int4( @@ -807,23 +960,32 @@ def generate_data_2stages( shuffle_weight(w2_qt, (16, 16), use_int4=True) ) ) - elif q_dtype_w == dtypes.fp4x2: + elif q_dtype_w == dtypes.fp4x2 and q_dtype_a == dtypes.fp4x2: w1_qt_shffle_ck = shuffle_weight(w1_qt, (16, 16)) w2_qt_shffle_ck = shuffle_weight(w2_qt, (16, 16)) + elif q_dtype_w == dtypes.fp4x2 and q_dtype_a == dtypes.fp8: + # a8w4 per_1x32 stage1 just support tune a1_cast now. + w1_qt_shffle_ck = shuffle_weight_a16w4(w1_qt, 16, True) + w1_scale_aiter = shuffle_scale_a16w4(w1_scale, expert, True) + w2_qt_shffle_ck = shuffle_weight_a16w4(w2_qt, 16, False) + w2_scale_aiter = shuffle_scale_a16w4(w2_scale, expert, False) else: w1_qt_shffle_ck = w1_qt_shffle w2_qt_shffle_ck = w2_qt_shffle - w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) - w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + + if w1_scale_aiter is None: + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) w1_qt_shffle_flydsl = w1_qt_shffle_ck w2_qt_shffle_flydsl = w2_qt_shffle_ck w1_scale_flydsl = w1_scale_aiter w2_scale_flydsl = w2_scale_aiter + if stage == 1: if not doweight_stage1: sorted_weights = None - if q_type == QuantType.per_1x32: + if q_type == QuantType.per_1x32 and q_dtype_a == dtypes.fp4x2: a1_scale_fp4_sort = moe_mxfp4_sort( a1_scale, # a1_scale[: token * topk, :].view(token, topk, -1), sorted_ids=sorted_ids, @@ -834,6 +996,9 @@ def generate_data_2stages( else: a1_scale_fp4_sort = a1_scale + # For the _fp8 FlyDSL variant (a_scale_one=True): cast bf16 input to fp8. + a1_qt_fp8_cast = input.to(dtypes.fp8) + return ( a1_qt, # 0 w1_qt_shffle_ck, # 1 @@ -855,15 +1020,40 @@ def generate_data_2stages( w2_qt_shffle_flydsl, # 17 w1_scale_flydsl, # 18 w2_scale_flydsl, # 19 + a1_qt_fp8_cast, # 20 — fp8-cast input for _fp8 FlyDSL variant + None, # 21 — None placeholder (a1_scale=None for a8w4 torch ref) + ( + torch.clamp( + torch.randn( + (expert, inter_dim * 2), dtype=dtype, device=device + ), + -1.0, + 1.0, + ).to(torch.float32) + if ( + act_type == ActivationType.Swiglu + and q_type == QuantType.per_1x32 + and q_dtype_a == dtypes.fp8 + and dtype in [dtypes.bf16, dtypes.fp16] + ) + else None + ), # 22 — bias for stage1 (a8w4 only) ) elif stage == 2: + # a8w4: a1_scale is dummy non-None → torch_moe_stage1's per_1x32 + # branch would call mxfp4_to_f32(bf16), pass None to take a16w4 path. + ref_a1_scale = ( + None + if (q_type == QuantType.per_1x32 and q_dtype_a == dtypes.fp8) + else a1_scale + ) ref1 = FmoeTuner.run_torch_moe_stage1( a1_qt, w1_qt, w2_qt, topk_weights, topk_ids, - a1_scale=a1_scale, + a1_scale=ref_a1_scale, w1_scale=w1_scale, dtype=dtype, activation=act_type, @@ -871,6 +1061,9 @@ def generate_data_2stages( doweight_stage1=doweight_stage1, topk=topk, ) + # ref1 is always bf16 + ref1_bf16 = ref1 + if q_type == QuantType.per_1x128: ref1, ref_scale = aiter.pertoken_quant( ref1.view(ref1.shape[0], -1, 128), quant_dtype=q_dtype_a @@ -880,7 +1073,7 @@ def generate_data_2stages( a2_qt = ref1 a2_scale = ref_scale a2_scale_mxfp4_sort = a2_scale - elif q_type == QuantType.per_1x32: + elif q_type == QuantType.per_1x32 and q_dtype_a == dtypes.fp4x2: torch_quant = aiter.get_torch_quant(q_type) a2_qt, a2_scale = torch_quant(ref1, quant_dtype=q_dtype_a) a2_scale_mxfp4_sort = moe_mxfp4_sort( @@ -890,6 +1083,17 @@ def generate_data_2stages( token_num=token, block_size=blockM, ) + elif q_type == QuantType.per_1x32 and q_dtype_a == dtypes.fp8: + # FlyDSL stage2 receives fp8 input + a2_qt = ref1.to(dtypes.fp8) + M = sorted_ids.shape[0] + N = a2_qt.shape[-1] + a2_scale = torch.ones( + [token * topk, N // 32], dtype=dtypes.fp8_e8m0, device=a2_qt.device + ) + a2_scale_mxfp4_sort = torch.ones( + [M, N // 32], dtype=dtypes.fp8_e8m0, device=a2_qt.device + ) else: torch_quant = aiter.get_torch_quant(q_type) a2_qt, a2_scale = torch_quant(ref1, quant_dtype=q_dtype_a) @@ -897,8 +1101,9 @@ def generate_data_2stages( a2_qt = a2_qt.view(token, topk, -1) if doweight_stage1: sorted_weights = None + return ( - a2_qt, # 0 + a2_qt, # 0 — fp8 for FlyDSL (a8w4), fp4x2 for a4w4 w1_qt_shffle_ck, # 1 w2_qt_shffle_ck, # 2 a2_scale, # 3 @@ -918,6 +1123,22 @@ def generate_data_2stages( w2_qt_shffle_flydsl, # 17 w1_scale_flydsl, # 18 w2_scale_flydsl, # 19 + ref1_bf16, # 20 — bf16 stage1 output for torch ref (a8w4) + None, # 21 — None placeholder (a2_scale=None for a8w4 torch ref) + ( + torch.clamp( + torch.randn((expert, model_dim), dtype=dtype, device=device), + -1.0, + 1.0, + ).to(torch.float32) + if ( + act_type == ActivationType.Swiglu + and q_type == QuantType.per_1x32 + and q_dtype_a == dtypes.fp8 + and dtype in [dtypes.bf16, dtypes.fp16] + ) + else None + ), # 22 — bias for stage2 (a8w4 only) ) @staticmethod @@ -1025,13 +1246,15 @@ def run_torch_moe_stage1( w1_scale, sorted_ids=None, num_valid_ids=None, + w1_bias=None, dtype=dtypes.bf16, activation=ActivationType.Silu, quant_type=QuantType.No, doweight_stage1=False, topk=1, blockM=32, - fuse_fq=False, + fuse_fp4=False, + fuse_fp8=False, ): ref1 = torch_moe_stage1( a1_qt, @@ -1044,14 +1267,23 @@ def run_torch_moe_stage1( dtype=dtype, a1_scale=a1_scale, w1_scale=w1_scale, + w1_bias=w1_bias, doweight=doweight_stage1, ) token_num = a1_qt.shape[0] - if fuse_fq: + if fuse_fp4: from aiter.ops.quant import per_1x32_f4_quant a2, a2_scale = per_1x32_f4_quant(ref1, quant_dtype=dtypes.fp4x2) return a2.view(token_num, topk, -1) + elif fuse_fp8: + inter_dim = ref1.shape[-1] + a2_fp8_bytes, _a2_scale_e8m0 = torch_dynamic_mxfp8_quant( + ref1.reshape(-1, inter_dim) + ) + a2 = a2_fp8_bytes.view(dtypes.fp8).view(token_num, topk, inter_dim) + return a2 + if quant_type == QuantType.per_1x128: ref1, ref_scale = aiter.pertoken_quant( ref1.view(ref1.shape[0], -1, 128), quant_dtype=a1_qt.dtype @@ -1068,9 +1300,10 @@ def run_torch_moe_stage2( topk_ids, a2_scale, w2_scale, - dtype, - quant_type, - doweight_stage1, + w2_bias=None, + dtype=dtypes.bf16, + quant_type=QuantType.No, + doweight_stage1=False, ): return torch_moe_stage2( a2_qt, @@ -1082,6 +1315,7 @@ def run_torch_moe_stage2( quant_type, a2_scale=a2_scale, w2_scale=w2_scale, + w2_bias=w2_bias, doweight=not doweight_stage1, ) @@ -1668,7 +1902,13 @@ def gen_2stages_asm1_task(self, key, blockMs): kernels_list_csv.format(quantDtype=quantDtype, extraInfo=extraInfo) ) for blockM in blockMs: - if use_g1u1 and q_dtype_w != torch.int4: + # per_1x32 + fp4x2 is a8w4 (MX-FP8 act + MX-FP4 weight); no ASM kernel exists + # for this combo — the pertokenFp8 CSV only covers per_Token quant. + if ( + use_g1u1 + and q_dtype_w != torch.int4 + and not (q_type == QuantType.per_1x32 and q_dtype_w == dtypes.fp4x2) + ): for el in asm_kernels.get(blockM, []): tasks.append( ( @@ -1757,6 +1997,15 @@ def gen_2stages_task(self, key, blockMs): doweight_stage1, ) = info + _is_a8w4 = ( + q_dtype_a == dtypes.fp8 + and q_dtype_w == dtypes.fp4x2 + and q_type == QuantType.per_1x32 + ) + + if _is_a8w4: + return self._gen_2stages_task_cktile(info, blockMs) + _, ck_stage1_kernels = get_gemm1_kernels_list( dtype2str_dict[q_dtype_a], dtype2str_dict[q_dtype_w], @@ -1840,7 +2089,7 @@ def gen_2stages_task(self, key, blockMs): {}, FmoeTuner.run_torch_moe_stage1, ( - [0, 10, 11, 12, 13, 3, 4, 5, 8], + [0, 10, 11, 12, 13, 3, 4, 5, 8, 22], dtype, act_type, q_type, @@ -1852,68 +2101,19 @@ def gen_2stages_task(self, key, blockMs): (None), 0.01, 0.01, - True, + None, ) ) - for sk in splitk_list: - for kernel in ck_stage1_splitk_kernels.values(): - if kernel.MPerBlock != blockM: - continue - tag_name = f"{kernel.name}_sk{sk}" - tasks_ck.append( - ( - (info, "stage1", tag_name, blockM), - FmoeTuner.generate_data_2stages, - ( - token, - model_dim, - inter_dim, - expert, - topk, - act_type, - dtype, - q_dtype_a, - q_dtype_w, - q_type, - use_g1u1, - doweight_stage1, - blockM, - 1, - ), - FmoeTuner.ck_moe_stage1_fwd_out, - ( - [0, 1, 2, 5, 6, 7, 8, 15, 14], - dtype, - topk, - kernel.name, - blockM, - q_type, - act_type, - sk, - ), - {}, - FmoeTuner.run_torch_moe_stage1, - ( - [0, 10, 11, 12, 13, 3, 4, 5, 8], - dtype, - act_type, - q_type, - doweight_stage1, - topk, - blockM, - ), - {}, - (None), - 0.01, - 0.01, - True, - ) - ) - for kernel in ck_stage2_kernels.values(): if kernel.MPerBlock != blockM: continue + s2_ref_args = ( + [0, 10, 11, 12, 13, 3, 4, 22], + dtype, + q_type, + doweight_stage1, + ) tasks_ck.append( ( (info, "stage2", kernel.name, blockM), # tag @@ -1946,21 +2146,133 @@ def gen_2stages_task(self, key, blockMs): ), {}, FmoeTuner.run_torch_moe_stage2, - ( - [0, 10, 11, 12, 13, 3, 4], - dtype, - q_type, - doweight_stage1, - ), + s2_ref_args, {}, (None), 0.01, 0.01, - True, + None, ) ) return tasks_ck + def _gen_2stages_task_cktile(self, info, blockMs): + """A8W4 (fp8 activation + fp4 weight + per_1x32) uses cktile path.""" + tasks_ck = [] + ( + cu_num, + token, + model_dim, + inter_dim, + expert, + topk, + act_type, + dtype, + q_dtype_a, + q_dtype_w, + q_type, + use_g1u1, + doweight_stage1, + ) = info + + _gen_data_args_s1 = ( + token, + model_dim, + inter_dim, + expert, + topk, + act_type, + dtype, + q_dtype_a, + q_dtype_w, + q_type, + use_g1u1, + doweight_stage1, + ) + _gen_data_args_s2 = ( + token, + model_dim, + inter_dim, + expert, + topk, + act_type, + dtype, + q_dtype_a, + q_dtype_w, + q_type, + use_g1u1, + doweight_stage1, + ) + + for blockM in blockMs: + if blockM not in [32, 64] or not use_g1u1: + continue + + cktile_s1_name = f"cktile_a8w4_bm{blockM}" + tasks_ck.append( + ( + (info, "stage1", cktile_s1_name, blockM), + FmoeTuner.generate_data_2stages, + (*_gen_data_args_s1, blockM, 1), + FmoeTuner.cktile_moe_stage1_out, + ( + [20, 1, 2, 5, 6, 7, 8, 15, 22], + dtype, + topk, + blockM, + act_type, + ), + {}, + FmoeTuner.run_torch_moe_stage1, + ( + [0, 10, 11, 12, 13, 3, 4, 5, 8, 22], + dtype, + act_type, + q_type, + doweight_stage1, + topk, + blockM, + ), + {}, + (None), + 0.01, + 0.01, + cosine_diff_compare, + ) + ) + + cktile_s2_name = f"cktile_a8w4_bm{blockM}" + tasks_ck.append( + ( + (info, "stage2", cktile_s2_name, blockM), + FmoeTuner.generate_data_2stages, + (*_gen_data_args_s2, blockM, 2), + FmoeTuner.cktile_moe_stage2_out, + ( + [0, 1, 2, 5, 6, 7, 8, 15, 14, 22], + dtype, + topk, + blockM, + act_type, + ), + {}, + FmoeTuner.run_torch_moe_stage2, + ( + [20, 10, 11, 12, 13, 21, 4, 22], + dtype, + q_type, + doweight_stage1, + ), + {}, + (None), + 0.01, + 0.01, + cosine_diff_compare, + ) + ) + + return tasks_ck + def gen_flydsl_2stages_task(self, info, blockMs): tasks_flydsl = [] if not is_flydsl_available(): @@ -2005,23 +2317,45 @@ def gen_flydsl_2stages_task(self, info, blockMs): if blockM not in [32, 64, 128] or not use_g1u1: continue for kname, kparams in flydsl_s1_kernels.items(): - ktm = kparams["tile_m"] - if ktm != blockM and not (ktm == 16 and blockM == 32 and token <= 16): - continue - is_splitk = kparams.get("k_batch", 1) > 1 - if is_splitk: - fq_params = {**kparams, "fuse_fp4_quant": True} - s1_variants = [(kname + "_fq", fq_params, True)] + # (kernel_name, kparams, is_fp4, is_fp8) + # out_dtype encodes fused quant type: "fp4" or "fp8" + # a8w4 (a_dtype_str="fp8"): stage2 expects fp8 activations → out_dtype="fp8" + # a4w4 (a_dtype_str="fp4"): stage2 expects fp4 activations → out_dtype="fp4" + if a_dtype_str == "fp8": + fp8_params = { + **kparams, + "out_dtype": "fp8", + "a_scale_one": True, + "gate_mode": "interleave", + } + nonfused_params = {**kparams, "a_scale_one": True} + if is_splitk: + s1_variants = [(kname + "_fp8", fp8_params, False, True)] + else: + s1_variants = [ + (kname, nonfused_params, False, False), + (kname + "_fp8", fp8_params, False, True), + ] else: - s1_variants = [(kname, kparams, False)] - fq_params = {**kparams, "fuse_fp4_quant": True} - s1_variants.append((kname + "_fq", fq_params, True)) - - for s1_name, s1_params, is_fq in s1_variants: + fp4_params = {**kparams, "out_dtype": "fp4"} + if is_splitk: + s1_variants = [(kname + "_fp4", fp4_params, True, False)] + else: + s1_variants = [ + (kname, kparams, False, False), + (kname + "_fp4", fp4_params, True, False), + ] + + for s1_name, s1_params, is_fp4, is_fp8 in s1_variants: + s1_compare_fn = None + if is_fp8 or a_dtype_str == "fp8": + # a8w4: precision differs from torch ref; use cosine + # diff (logits diff) instead of checkAllclose. + s1_compare_fn = cosine_diff_compare ref_args_extra = ( - [0, 10, 11, 12, 13, 3, 4, 5, 8], + [0, 10, 11, 12, 13, 3, 4, 5, 8, 22], dtype, act_type, q_type, @@ -2029,8 +2363,17 @@ def gen_flydsl_2stages_task(self, info, blockMs): topk, blockM, ) - if is_fq: + if is_fp4: ref_args_extra = ref_args_extra + (True,) + elif is_fp8: + ref_args_extra = ref_args_extra + (False, True) + s1_ref_func = FmoeTuner.run_torch_moe_stage1 + s1_ref_args = ref_args_extra + s1_ref_kwargs = {} + s1_ref = None + + # _fp8 variant uses direct fp8-cast activation (index 20) + a1_idx = 20 if is_fp8 else 0 tasks_flydsl.append( ( (info, "stage1", s1_name, blockM), @@ -2053,22 +2396,23 @@ def gen_flydsl_2stages_task(self, info, blockMs): ), FmoeTuner.run_flydsl_stage1_out, ( - [0, 1, 5, 6, 7, 8, 15, 14], + [a1_idx, 1, 5, 6, 7, 8, 15, 14, 22], dtype, topk, s1_params, blockM, + q_dtype_a, q_type, act_type, ), {}, - FmoeTuner.run_torch_moe_stage1, - ref_args_extra, - {}, - (None), + s1_ref_func, + s1_ref_args, + s1_ref_kwargs, + s1_ref, 0.01, 0.01, - True, + s1_compare_fn, ) ) @@ -2081,6 +2425,28 @@ def gen_flydsl_2stages_task(self, info, blockMs): continue s2_kparams = {**kparams, "sort_block_m": blockM} s2_kname = kname if s2_tile_m == blockM else f"{kname}_sbm{blockM}" + + s2_ref_kwargs = {} + s2_compare_fn = None + if a_dtype_str == "fp8": + s2_compare_fn = cosine_diff_compare + # Use bf16 stage1 output (idx 20) and a2_scale=None (idx 21) + # so torch ref takes the a16w4 path instead of mxfp4_to_f32. + s2_ref_args = ( + [20, 10, 11, 12, 13, 21, 4, 22], + dtype, + q_type, + doweight_stage1, + ) + else: + s2_ref_args = ( + [0, 10, 11, 12, 13, 3, 4, 22], + dtype, + q_type, + doweight_stage1, + ) + s2_ref_func = FmoeTuner.run_torch_moe_stage2 + tasks_flydsl.append( ( (info, "stage2", s2_kname, blockM), @@ -2103,7 +2469,7 @@ def gen_flydsl_2stages_task(self, info, blockMs): ), FmoeTuner.run_flydsl_stage2_out, ( - [0, 17, 5, 6, 7, 8, 19, 14, 9], + [0, 17, 5, 6, 7, 8, 19, 14, 9, 22], dtype, topk, s2_kparams, @@ -2112,18 +2478,13 @@ def gen_flydsl_2stages_task(self, info, blockMs): act_type, ), {}, - FmoeTuner.run_torch_moe_stage2, - ( - [0, 10, 11, 12, 13, 3, 4], - dtype, - q_type, - doweight_stage1, - ), - {}, + s2_ref_func, + s2_ref_args, + s2_ref_kwargs, (None), 0.01, 0.01, - True, + s2_compare_fn, ) ) @@ -2663,47 +3024,85 @@ def post_process(self, results, args, topk=-1, fast_mode=False): self.failed = pd.concat([self.failed, failedf], axis=0) continue if q_type == QuantType.per_1x32: - from aiter.test_common import run_perftest - from aiter.ops.triton.quant.fused_mxfp4_quant import ( - fused_dynamic_mxfp4_quant_moe_sort, - ) + # For a4w4 (fp4 activation), a separate fp4-quant+sort step is needed + # between stage1 (bf16 output) and stage2 (fp4 input). Benchmark its + # cost and add it to non-fused kernels so comparisons are fair. + # + # For a8w4 (fp8 activation), non-fused paths assume bf16 stage1 output + # then a separate cast to fp8 before stage2; benchmark that cast + # (simple .to(dtypes.fp8)) and add it to kernels whose stage1 name does + # not end with _fp8 (those fuse the cast in stage1). + if q_dtype_a == dtypes.fp4x2: + from aiter.test_common import run_perftest + from aiter.ops.triton.quant.fused_mxfp4_quant import ( + fused_dynamic_mxfp4_quant_moe_sort, + ) + + us_qs_cache = {} + for bm in profileDF["block_m"].unique(): + bm_int = int(bm) + block_size = max(32, bm_int) + num_sorted = ( + (token * topk + block_size - 1) // block_size + ) * block_size + dummy_act = torch.randn( + token * topk, inter_dim, dtype=dtype, device="cuda" + ) + dummy_sorted_ids = torch.arange( + num_sorted, dtype=torch.int32, device="cuda" + ) + dummy_num_valid = torch.tensor( + [token * topk], dtype=torch.int32, device="cuda" + ) + _, us_qs = run_perftest( + fused_dynamic_mxfp4_quant_moe_sort, + dummy_act, + sorted_ids=dummy_sorted_ids, + num_valid_ids=dummy_num_valid, + token_num=token, + topk=topk, + block_size=block_size, + ) + us_qs_cache[bm] = round(us_qs, 4) + print( + f" quant_sort benchmark: blockM={bm_int}, us={us_qs_cache[bm]}" + ) + profileDF["us_quant_sort"] = profileDF["block_m"].map(us_qs_cache) + # _fp4 kernels already fuse the fp4-quant+sort; skip cost addition + is_fp4 = profileDF["kernelName1"].astype(str).str.endswith("_fp4") + profileDF.loc[~is_fp4, "us1"] = ( + profileDF.loc[~is_fp4, "us1"] + + profileDF.loc[~is_fp4, "us_quant_sort"] + ) + profileDF.drop(columns=["us_quant_sort"], inplace=True) + elif q_dtype_a == dtypes.fp8: + from aiter.test_common import run_perftest - us_qs_cache = {} - for bm in profileDF["block_m"].unique(): - bm_int = int(bm) - block_size = max(32, bm_int) - num_sorted = ( - (token * topk + block_size - 1) // block_size - ) * block_size dummy_act = torch.randn( token * topk, inter_dim, dtype=dtype, device="cuda" ) - dummy_sorted_ids = torch.arange( - num_sorted, dtype=torch.int32, device="cuda" - ) - dummy_num_valid = torch.tensor( - [token * topk], dtype=torch.int32, device="cuda" - ) - _, us_qs = run_perftest( - fused_dynamic_mxfp4_quant_moe_sort, - dummy_act, - sorted_ids=dummy_sorted_ids, - num_valid_ids=dummy_num_valid, - token_num=token, - topk=topk, - block_size=block_size, - ) - us_qs_cache[bm] = round(us_qs, 4) - print( - f" quant_sort benchmark: blockM={bm_int}, us={us_qs_cache[bm]}" + + def _act_to_fp8(x): + _scale_tmp = torch.ones( + [x.shape[0], x.shape[1] // 32], + dtype=dtypes.fp8_e8m0, + device=x.device, + ) + return x.to(dtypes.fp8) + + _, us_fp8_cast = run_perftest(_act_to_fp8, dummy_act) + us_fp8_cast = round(us_fp8_cast, 4) + print(f" fp8 activation cast benchmark: us={us_fp8_cast}") + us_qs_cache = {} + for bm in profileDF["block_m"].unique(): + us_qs_cache[bm] = us_fp8_cast + profileDF["us_quant_sort"] = profileDF["block_m"].map(us_qs_cache) + is_fp8 = profileDF["kernelName1"].astype(str).str.endswith("_fp8") + profileDF.loc[~is_fp8, "us1"] = ( + profileDF.loc[~is_fp8, "us1"] + + profileDF.loc[~is_fp8, "us_quant_sort"] ) - profileDF["us_quant_sort"] = profileDF["block_m"].map(us_qs_cache) - is_fq = profileDF["kernelName1"].astype(str).str.contains("_fq") - profileDF.loc[~is_fq, "us1"] = ( - profileDF.loc[~is_fq, "us1"] - + profileDF.loc[~is_fq, "us_quant_sort"] - ) - profileDF.drop(columns=["us_quant_sort"], inplace=True) + profileDF.drop(columns=["us_quant_sort"], inplace=True) profileDF["us"] = round(profileDF["us1"] + profileDF["us2"], 4) results = profileDF.apply( @@ -2734,7 +3133,7 @@ def post_process(self, results, args, topk=-1, fast_mode=False): tmpprofileDF.to_csv(args.profile_file, index=False) best_one = profileDF.loc[profileDF["us"].idxmin()].copy() print( - f"Tuning result for {key} is {best_one['block_m'] ,best_one['kernelName1'], best_one['kernelName2'], best_one['err1'], best_one['err2'], best_one['run_1stage']} {best_one['us']} us, {best_one['tflops']} TFLOPS, {best_one['bw']} GB/s" + f"Tuning result for {key} is {best_one['block_m'], best_one['kernelName1'], best_one['kernelName2'], best_one['err1'], best_one['err2'], best_one['run_1stage']} {best_one['us']} us, {best_one['tflops']} TFLOPS, {best_one['bw']} GB/s" ) best_one["act_type"] = str(best_one["act_type"]) best_one["q_type"] = str(best_one["q_type"]) @@ -2844,7 +3243,6 @@ def pre_process(self, args): if __name__ == "__main__": - key = [ "cu_num", "token", diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index b5e716fb40..73edb37f8e 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -10,6 +10,7 @@ from aiter.utility import fp4_utils from aiter.jit.utils.chip_info import get_gfx import argparse +import os import pandas as pd import logging @@ -29,6 +30,9 @@ torch.int4 = getattr(torch, "int4", torch.uint32) torch.set_default_device("cuda") +AITER_MOE_EXPERT_BALANCE = ( + os.environ.get("AITER_MOE_EXPERT_BALANCE", "False").lower() == "true" +) @benchmark() @@ -68,7 +72,17 @@ def test_fmoe( w2[:, :, -intermediate_pad:] = 0 w2[:, -hidden_pad:, :] = 0 exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) - score = torch.randn((token, E), dtype=dtype) + if AITER_MOE_EXPERT_BALANCE: + score = torch.zeros((token, E), dtype=dtype) + start_col = 0 + end_col = topk + for token_id in range(token): + score[token_id, start_col:end_col] = 1.0 + start_col = end_col % E + end_col = start_col + topk + else: + score = torch.randn((token, E), dtype=dtype) + topk_weights, topk_ids = fused_topk(input, score, topk, True) if qType == aiter.QuantType.per_Tensor: