@@ -182,8 +182,7 @@ def te_layout_transform(data, name):
182182 )
183183
184184 def set_axis_sep (axis_sep : list , sch : tir .schedule , buffer_type : str ):
185- if len (axis_sep ) != 0 :
186- sch .set_axis_separator (primfunc_name , (buffer_type , 0 ), axis_separators = axis_sep )
185+ sch .set_axis_separator (primfunc_name , (buffer_type , 0 ), axis_separators = axis_sep )
187186
188187 index_map : tvm .tir .IndexMap = call .attrs .index_map
189188 pad_value = call .attrs .pad_value
@@ -199,7 +198,7 @@ def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
199198 input_axis_separators : tvm .tir .IndexMap .AXIS_SEPARATOR = call .attrs .input_axis_separators
200199
201200 # Convert to list from array
202- axis_separators = list ( map ( lambda x : x . value , axis_separators ))
201+ axis_separators = [ int ( sep ) for sep in axis_separators ]
203202 primfunc_name = "te_layout_transform"
204203 _ , padding_predicate = index_map .non_surjective_inverse (call .args [0 ].struct_info .shape )
205204 if not isinstance (padding_predicate , tvm .tir .expr .IntImm ):
@@ -214,7 +213,7 @@ def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
214213 sch .transform_layout (primfunc_name , ("write" , 0 ), index_map , pad_value )
215214 set_axis_sep (axis_separators , sch , "write" )
216215 if input_axis_separators is not None :
217- input_axis_separators = list ( map ( lambda x : x . value , input_axis_separators ))
216+ input_axis_separators = [ int ( sep ) for sep in input_axis_separators ]
218217 set_axis_sep (input_axis_separators , sch , "read" )
219218 gvar = bb .add_func (sch .mod ["main" ], primfunc_name )
220219 output_shape = index_map .map_shape (list (call_args [0 ].struct_info .shape ))
0 commit comments