@@ -62,9 +62,14 @@ def main( # type: ignore
6262 for i0 , i1 , i2 , i3 , i4 in T .grid (1 , 1 , 16 , 16 , 3 ):
6363 with T .block ("T_layout_trans" ):
6464 ax0 , ax1 , ax2 , ax3 , ax4 = T .axis .remap ("SSSSS" , [i0 , i1 , i2 , i3 , i4 ])
65- T .reads (placeholder [0 , ax4 , ax2 , ax3 ])
65+ T .reads (placeholder [ax0 , ax1 * 3 + ax4 , ax2 , ax3 ])
6666 T .writes (T_layout_trans [ax0 , ax1 , ax2 , ax3 , ax4 ])
67- T_layout_trans [ax0 , ax1 , ax2 , ax3 , ax4 ] = placeholder [0 , ax4 , ax2 , ax3 ]
67+ T_layout_trans [ax0 , ax1 , ax2 , ax3 , ax4 ] = T .if_then_else (
68+ ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16 , # type: ignore
69+ placeholder [ax0 , ax1 * 3 + ax4 , ax2 , ax3 ],
70+ T .float32 (0 ),
71+ dtype = "float32" ,
72+ )
6873
6974
7075@tvm .script .ir_module
@@ -79,19 +84,18 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B
7984 for i0 , i1 , i2 , i3 , i4 in T .grid (1 , 1 , 20 , 20 , 3 ):
8085 with T .block ("data_pad" ):
8186 i0_1 , i1_1 , i2_1 , i3_1 , i4_1 = T .axis .remap ("SSSSS" , [i0 , i1 , i2 , i3 , i4 ])
82- T .reads (placeholder [0 , 0 , i2_1 - 2 , i3_1 - 2 , i4_1 ]) # type: ignore
87+ T .reads (placeholder [i0_1 , i1_1 , i2_1 - 2 , i3_1 - 2 , i4_1 ])
8388 T .writes (data_pad [i0_1 , i1_1 , i2_1 , i3_1 , i4_1 ])
84- data_pad [i0_1 , i1_1 , i2_1 , i3_1 , i4_1 ] = T .if_then_else (2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18 , placeholder [0 , 0 , i2_1 - 2 , i3_1 - 2 , i4_1 ], T .float32 (0 ), dtype = "float32" ) # type: ignore # pylint: disable=R1716
89+ data_pad [i0_1 , i1_1 , i2_1 , i3_1 , i4_1 ] = T .if_then_else (2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18 , placeholder [i0_1 , i1_1 , i2_1 - 2 , i3_1 - 2 , i4_1 ], T .float32 (0 ), dtype = "float32" ) # type: ignore # pylint: disable=R1716
8590 for i0 , i1 , i2 , i3 , i4 , i5 , i6 , i7 in T .grid (1 , 2 , 16 , 16 , 4 , 3 , 5 , 5 ):
8691 with T .block ("conv2d_NCHWc" ):
8792 n , oc_chunk , oh , ow , oc_block , ic , kh , kw = T .axis .remap ("SSSSSRRR" , [i0 , i1 , i2 , i3 , i4 , i5 , i6 , i7 ])
88- T .reads (data_pad [0 , 0 , oh + kh , ow + kw , ic ], placeholder_1 [oc_chunk , 0 , kh , kw , ic , oc_block ]) # type: ignore
93+ T .reads (data_pad [n , ic // 3 , oh + kh , ow + kw , ic % 3 ], placeholder_1 [oc_chunk , ic // 3 , kh , kw , ic % 3 , oc_block ]) # type: ignore
8994 T .writes (conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ])
9095 T .block_attr ({"workload" :["conv2d_NCHWc.x86" , ["TENSOR" , [1 , 1 , 16 , 16 , 3 ], "float32" ], ["TENSOR" , [2 , 1 , 5 , 5 , 3 , 4 ], "float32" ], [1 , 1 ], [2 , 2 , 2 , 2 ], [1 , 1 ], "NCHW3c" , "NCHW4c" , "float32" ]})
9196 with T .init ():
9297 conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] = T .float32 (0 )
93- conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] = conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] + data_pad [0 , 0 , oh + kh , ow + kw , ic ] * placeholder_1 [oc_chunk , 0 , kh , kw , ic , oc_block ] # type: ignore
94-
98+ conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] = conv2d_NCHWc [n , oc_chunk , oh , ow , oc_block ] + data_pad [n , ic // 3 , oh + kh , ow + kw , ic % 3 ] * placeholder_1 [oc_chunk , ic // 3 , kh , kw , ic % 3 , oc_block ] # type: ignore
9599
96100@tvm .script .ir_module
97101class tvmgen_default_fused_layout_transform_1 :
@@ -104,9 +108,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.
104108 for i0 , i1 , i2 , i3 in T .grid (1 , 8 , 16 , 16 ):
105109 with T .block ("T_layout_trans" ):
106110 ax0 , ax1 , ax2 , ax3 = T .axis .remap ("SSSS" , [i0 , i1 , i2 , i3 ])
107- T .reads (placeholder [0 , ax1 // 4 , ax2 , ax3 , ax1 % 4 ]) # type: ignore
111+ T .reads (placeholder [ax0 , ax1 // 4 , ax2 , ax3 , ax1 % 4 ]) # type: ignore
108112 T .writes (T_layout_trans [ax0 , ax1 , ax2 , ax3 ])
109- T_layout_trans [ax0 , ax1 , ax2 , ax3 ] = placeholder [0 , ax1 // 4 , ax2 , ax3 , ax1 % 4 ] # type: ignore
113+ T_layout_trans [ax0 , ax1 , ax2 , ax3 ] = T . if_then_else ( ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16 , placeholder [ax0 , ax1 // 4 , ax2 , ax3 , ax1 % 4 ], T . float32 ( 0 ), dtype = "float32" ) # type: ignore
110114
111115# fmt: on
112116# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
0 commit comments