@@ -685,47 +685,54 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
 class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
     # fmt: off
     @T.prim_func
-    def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
+    def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
         T.func_attr({"tir.noalias": T.bool(True)})
         seq_len = T.int64()
-        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
-        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
+        rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
+        T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
         # with T.block("root"):
         compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
         dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
+        matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16")
         for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
             with T.block("compute"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(lv840[v_i0 // T.int64(8), v_i1])
+                T.reads(lv452[v_i0 // T.int64(8), v_i1])
                 T.writes(compute[v_i0, v_i1])
-                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
+                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
         for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
             with T.block("dequantize"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
+                T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1])
                 T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
-                dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
+                dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1]
         for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)):
             with T.block("matmul"):
                 v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
-                T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
+                T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
                 T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
                 with T.init():
                     matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
-                matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
+                matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
+                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
+                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]

     @T.prim_func
-    def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
+    def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
         T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
         seq_len = T.int64()
-        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
-        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
+        rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
+        T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
         # with T.block("root"):
         dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
-        rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
+        rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
         matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local")
-        lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
-        lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
+        lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
+        lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
         for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
             for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"):
                 for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
@@ -743,57 +750,57 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T
                             for ax0 in range(T.int64(4)):
                                 for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                     for ax1_1 in T.vectorized(T.int64(8)):
-                                        with T.block("rms_norm260_pad"):
+                                        with T.block("rms_norm130_pad"):
                                             v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                             v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
                                             v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
-                                            T.reads(rms_norm260[v0, v1, v2])
-                                            T.writes(rms_norm260_pad_shared[v0, v1, v2])
-                                            rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
+                                            T.reads(rms_norm130[v0, v1, v2])
+                                            T.writes(rms_norm130_pad_shared[v0, v1, v2])
+                                            rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0))
                             for k_1 in range(T.int64(8)):
                                 for ax0 in T.vectorized(T.int64(8)):
-                                    with T.block("lv841_local"):
+                                    with T.block("lv453_local"):
                                         v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1)
                                         v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                        T.reads(lv841[v0, v1])
-                                        T.writes(lv841_local[v0, v1])
-                                        lv841_local[v0, v1] = lv841[v0, v1]
+                                        T.reads(lv453[v0, v1])
+                                        T.writes(lv453_local[v0, v1])
+                                        lv453_local[v0, v1] = lv453[v0, v1]
                                 for k_2 in range(T.int64(4)):
                                     for ax0 in T.vectorized(T.int64(8)):
-                                        with T.block("lv840_local"):
+                                        with T.block("lv452_local"):
                                             v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
                                             v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                            T.reads(lv840[v0, v1])
-                                            T.writes(lv840_local[v0, v1])
-                                            lv840_local[v0, v1] = lv840[v0, v1]
+                                            T.reads(lv452[v0, v1])
+                                            T.writes(lv452_local[v0, v1])
+                                            lv452_local[v0, v1] = lv452[v0, v1]
                                     for k_3 in range(T.int64(8)):
                                         for ax0 in T.vectorized(T.int64(8)):
                                             with T.block("dequantize"):
                                                 v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
                                                 v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                                T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
+                                                T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1])
                                                 T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
-                                                dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
+                                                dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1]
                                         for i0_i1_fused_2 in range(T.int64(4)):
                                             for i2_2 in T.vectorized(T.int64(8)):
                                                 with T.block("matmul_update"):
                                                     v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                     v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
                                                     v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
                                                     v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
-                                                    T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
+                                                    T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
                                                     T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
-                                                    matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
-                        for ax0 in range(T.int64(4)):
-                            for ax1 in T.vectorized(T.int64(8)):
-                                with T.block("matmul_intermediate_pad"):
-                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
-                                    v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                                    T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
-                                    T.reads(matmul_intermediate_pad_local[v0, v1, v2])
-                                    T.writes(matmul_intermediate[v0, v1, v2])
-                                    matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2]
+                                                    matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
+                        for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
+                            for ax2 in T.vectorized(T.int64(8)):
+                                with T.block("T_add"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2)
+                                    T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len)
+                                    T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
+                                    T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
+                                    T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
     # fmt: on
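Note on the change above: `before` now ends with a fused bias add (`T_add`), so the computation is group-quantized 4-bit dequantization, a matmul against the RMS-norm output, then a bias. A minimal NumPy sketch of those semantics, for reference only (the function name is invented; shapes follow the buffers in the signature):

import numpy as np

def dequant_matmul_bias(lv452, lv453, rms_norm, bias):
    # lv452 (512, 12288) uint32: eight 4-bit weights packed per word along K (4096 / 8).
    # lv453 (128, 12288) float16: one scale per group of 32 consecutive K elements (4096 / 32).
    k = np.arange(4096)
    # "compute" block: unpack the 4-bit nibble for each (k, n) position.
    nibbles = (lv452[k // 8, :] >> ((k % 8) * 4)[:, None]) & 15
    # "dequantize" block: recentre around the zero point 7, then apply the group scale.
    weight = (nibbles.astype(np.float16) - np.float16(7)) * lv453[k // 32, :]
    # "matmul" + "T_add" blocks: (1, seq_len, 4096) @ (4096, 12288), plus bias (12288,).
    return rms_norm @ weight + bias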
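In `expected`, the epilogue now stores the bias-added result behind a single `T.where` bound check; the old write-back carried an extra clause on `i0_i1_fused_0` that is always true given the `blockIdx.y` extent. A quick standalone sanity check of the simplified guard under this tiling (32 padded rows per `blockIdx.y` tile, 4 rows per thread; the `threadIdx.y` extent of 8 is inferred from 32 / 4):

# Hypothetical check, not part of the test: each padded row index is generated
# exactly once, and the guard keeps exactly the rows below seq_len.
seq_len = 50  # arbitrary; leaves the second 32-row tile partially filled
rows = [
    fused_0 * 32 + fused_1 * 4 + ax1
    for fused_0 in range((seq_len + 31) // 32)  # blockIdx.y tiles
    for fused_1 in range(8)                     # threadIdx.y lanes (inferred)
    for ax1 in range(4)                         # rows per lane
]
stored = sorted(r for r in rows if r < seq_len)  # the T.where(...) predicate
assert stored == list(range(seq_len))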