@@ -204,6 +204,48 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
204204 CheckAllocates (allocate_info )(mod )
205205
206206
207+ def test_allocate_without_seq_stmt ():
208+ """
209+ Tests the case when an allocate statement does not have a sequence statement as its body.
210+ """
211+ # fmt: off
212+ @tvm .script .ir_module
213+ class Module :
214+ @T .prim_func
215+ def main (placeholder : T .Buffer [(8192 ,), "int8" ], ethosu_write : T .Buffer [(2048 ,), "int8" ], buffer_encoded : T .Buffer [(128 ,), "uint8" ], buffer_encoded_1 : T .Buffer [(32 ,), "uint8" ], buffer_encoded_2 : T .Buffer [(112 ,), "uint8" ], buffer_encoded_3 : T .Buffer [(32 ,), "uint8" ], buffer_encoded_4 : T .Buffer [(112 ,), "uint8" ], buffer_encoded_5 : T .Buffer [(32 ,), "uint8" ], buffer_encoded_6 : T .Buffer [(112 ,), "uint8" ], buffer_encoded_7 : T .Buffer [(32 ,), "uint8" ]) -> None :
216+ # function attr dict
217+ T .func_attr ({"from_legacy_te_schedule" : True , "global_symbol" : "main" , "tir.noalias" : True })
218+ T .preflattened_buffer (placeholder , [1 , 16 , 16 , 32 ], dtype = "int8" , data = placeholder .data )
219+ T .preflattened_buffer (ethosu_write , [1 , 16 , 16 , 8 ], dtype = "int8" , data = ethosu_write .data )
220+ # body
221+ placeholder_global = T .allocate ([128 ], "uint8" , "global" )
222+ placeholder_global_1 = T .allocate ([112 ], "uint8" , "global" )
223+ placeholder_global_2 = T .allocate ([112 ], "uint8" , "global" )
224+ placeholder_d_global = T .allocate ([32 ], "uint8" , "global" )
225+ placeholder_d_global_1 = T .allocate ([32 ], "uint8" , "global" )
226+ placeholder_d_global_2 = T .allocate ([32 ], "uint8" , "global" )
227+ placeholder_global_3 = T .allocate ([112 ], "uint8" , "global" )
228+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded [0 ], 128 , placeholder_global [0 ], dtype = "handle" ))
229+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_1 [0 ], 32 , placeholder_d_global [0 ], dtype = "handle" ))
230+ T .evaluate (T .call_extern ("ethosu_conv2d" , "int8" , 16 , 16 , 32 , 16 , 0 , 16 , placeholder [0 ], 0 , 0 , 0 , T .float32 (0.5 ), 10 , "NHWC" , 512 , 32 , 1 , "int8" , 16 , 16 , 2 , 16 , 0 , 16 , ethosu_write [0 ], 0 , 0 , 0 , T .float32 (0.25 ), 14 , "NHWC" , 128 , 8 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , placeholder_global [0 ], 128 , 12 , placeholder_d_global [0 ], 32 , 0 , 0 , 0 , 0 , "NONE" , 0 , 0 , "TFL" , "NONE" , 0 , 0 , 0 , dtype = "handle" ))
231+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_2 [0 ], 112 , placeholder_global_1 [0 ], dtype = "handle" ))
232+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_3 [0 ], 32 , placeholder_d_global_1 [0 ], dtype = "handle" ))
233+ T .evaluate (T .call_extern ("ethosu_conv2d" , "int8" , 16 , 16 , 32 , 16 , 0 , 16 , placeholder [0 ], 0 , 0 , 0 , T .float32 (0.5 ), 10 , "NHWC" , 512 , 32 , 1 , "int8" , 16 , 16 , 2 , 16 , 0 , 16 , ethosu_write [2 ], 0 , 0 , 0 , T .float32 (0.25 ), 14 , "NHWC" , 128 , 8 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , placeholder_global_1 [0 ], 112 , 12 , placeholder_d_global_1 [0 ], 32 , 0 , 0 , 0 , 0 , "NONE" , 0 , 0 , "TFL" , "NONE" , 0 , 0 , 0 , dtype = "handle" ))
234+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_4 [0 ], 112 , placeholder_global_2 [0 ], dtype = "handle" ))
235+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_5 [0 ], 32 , placeholder_d_global_2 [0 ], dtype = "handle" ))
236+ placeholder_d_global_3 = T .allocate ([32 ], "uint8" , "global" )
237+ T .evaluate (T .call_extern ("ethosu_conv2d" , "int8" , 16 , 16 , 32 , 16 , 0 , 16 , placeholder [0 ], 0 , 0 , 0 , T .float32 (0.5 ), 10 , "NHWC" , 512 , 32 , 1 , "int8" , 16 , 16 , 2 , 16 , 0 , 16 , ethosu_write [4 ], 0 , 0 , 0 , T .float32 (0.25 ), 14 , "NHWC" , 128 , 8 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , placeholder_global_2 [0 ], 112 , 12 , placeholder_d_global_2 [0 ], 32 , 0 , 0 , 0 , 0 , "NONE" , 0 , 0 , "TFL" , "NONE" , 0 , 0 , 0 , dtype = "handle" ))
238+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_6 [0 ], 112 , placeholder_global_3 [0 ], dtype = "handle" ))
239+ T .evaluate (T .call_extern ("ethosu_copy" , buffer_encoded_7 [0 ], 32 , placeholder_d_global_3 [0 ], dtype = "handle" ))
240+ T .evaluate (T .call_extern ("ethosu_conv2d" , "int8" , 16 , 16 , 32 , 16 , 0 , 16 , placeholder [0 ], 0 , 0 , 0 , T .float32 (0.5 ), 10 , "NHWC" , 512 , 32 , 1 , "int8" , 16 , 16 , 2 , 16 , 0 , 16 , ethosu_write [6 ], 0 , 0 , 0 , T .float32 (0.25 ), 14 , "NHWC" , 128 , 8 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , placeholder_global_3 [0 ], 112 , 12 , placeholder_d_global_3 [0 ], 32 , 0 , 0 , 0 , 0 , "NONE" , 0 , 0 , "TFL" , "NONE" , 0 , 0 , 0 , dtype = "handle" ))
241+ # fmt: on
242+
243+ mod = Module
244+ allocate_info = ExtractAllocateInfo ()(mod )
245+ mod = HoistAllocates ()(mod )
246+ CheckAllocates (allocate_info )(mod )
247+
248+
207249def test_multiple_prim_funcs ():
208250 @tvm .script .ir_module
209251 class Module :
0 commit comments