2222import pytest
2323
2424
25- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ))
26- def test_vectorize_loop (extent ):
25+ simple_target = tvm .target .Target ("llvm -mtriple=x86_64-linux-gnu" )
26+ sve_target = tvm .target .Target ("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve" )
27+
28+
29+ @pytest .mark .parametrize ("extent, target" , [(4 , simple_target ), (T .vscale () * 4 , sve_target )])
30+ def test_vectorize_loop (extent , target ):
2731 @I .ir_module
2832 class Before :
2933 @T .prim_func
@@ -37,8 +41,9 @@ class After:
3741 def main (A : T .Buffer ((16 ,), "float32" )):
3842 A [T .Ramp (0 , 1 , extent )] = T .Broadcast (1 , extent )
3943
40- mod = tvm .tir .transform .VectorizeLoop ()(Before )
41- tvm .ir .assert_structural_equal (mod , After )
44+ with tvm .target .Target (target ):
45+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
46+ tvm .ir .assert_structural_equal (mod , After )
4247
4348
4449def test_vectorize_vector ():
@@ -70,8 +75,9 @@ def main(A: T.Buffer((25,), "float32")):
7075 A [j * 4 : j * 4 + 4 ] = T .Broadcast (T .float32 (1 ), 4 )
7176
7277 error_msg = f"Creating scalable vectors from existing vectors is not supported."
73- with pytest .raises (tvm .error .InternalError , match = error_msg ):
74- tvm .tir .transform .VectorizeLoop ()(Module )
78+ with tvm .target .Target (sve_target ):
79+ with pytest .raises (tvm .error .InternalError , match = error_msg ):
80+ tvm .tir .transform .VectorizeLoop ()(Module )
7581
7682
7783def test_vectorize_vector_scalable_error2 ():
@@ -99,7 +105,8 @@ def main(A: T.Buffer((25,), "float32")):
99105
100106 error_msg = f"Vectorizing over existing scalable vectors is not supported."
101107 with pytest .raises (tvm .error .InternalError , match = error_msg ):
102- tvm .tir .transform .VectorizeLoop ()(Module )
108+ with tvm .target .Target (sve_target ):
109+ tvm .tir .transform .VectorizeLoop ()(Module )
103110
104111
105112def test_vectorize_vector_scalable_error4 ():
@@ -114,11 +121,12 @@ def main(A: T.Buffer((25,), "float32")):
114121
115122 error_msg = f"Creating scalable vectors from existing vectors is not supported."
116123 with pytest .raises (tvm .error .InternalError , match = error_msg ):
117- tvm .tir .transform .VectorizeLoop ()(Module )
124+ with tvm .target .Target (sve_target ):
125+ tvm .tir .transform .VectorizeLoop ()(Module )
118126
119127
120- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ) )
121- def test_vectorize_with_if (extent ):
128+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( T .vscale () * 4 , sve_target )] )
129+ def test_vectorize_with_if (extent , target ):
122130 @I .ir_module
123131 class Before :
124132 @T .prim_func
@@ -143,8 +151,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
143151 if i_s < n :
144152 A [i_s ] = T .float32 (2 )
145153
146- mod = tvm .tir .transform .VectorizeLoop ()(Before )
147- tvm .ir .assert_structural_equal (mod , After )
154+ with tvm .target .Target (target ):
155+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
156+ tvm .ir .assert_structural_equal (mod , After )
148157
149158
150159def test_vectorize_with_if_cond_int64 ():
@@ -157,8 +166,8 @@ def test_vectorize_with_if_cond_int64():
157166 f = tvm .build (s , [A , B ], "llvm" )
158167
159168
160- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ) )
161- def test_vectorize_let (extent ):
169+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( T .vscale () * 4 , sve_target )] )
170+ def test_vectorize_let (extent , target ):
162171 @I .ir_module
163172 class Before :
164173 @T .prim_func
@@ -174,12 +183,13 @@ def main(A: T.Buffer((25,), "float32")):
174183 v = A [T .Ramp (0 , 1 , extent )] + T .Broadcast (T .float32 (1 ), extent )
175184 A [T .Ramp (0 , 1 , extent )] = v + T .Broadcast (T .float32 (2 ), extent )
176185
177- mod = tvm .tir .transform .VectorizeLoop ()(Before )
178- tvm .ir .assert_structural_equal (mod , After )
186+ with tvm .target .Target (target ):
187+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
188+ tvm .ir .assert_structural_equal (mod , After )
179189
180190
181- @pytest .mark .parametrize ("extent" , (4 , tvm .tir .vscale () * 4 ) )
182- def test_vectorize_with_le_cond (extent ):
191+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( tvm .tir .vscale () * 4 , sve_target )] )
192+ def test_vectorize_with_le_cond (extent , target ):
183193 n = te .var ("n" )
184194 ib = tvm .tir .ir_builder .create ()
185195 A = ib .pointer ("float32" , name = "A" )
@@ -189,14 +199,16 @@ def test_vectorize_with_le_cond(extent):
189199 stmt = ib .get ()
190200
191201 mod = tvm .IRModule .from_expr (tvm .tir .PrimFunc ([A , n ], stmt ))
192- stmt = tvm .tir .transform .VectorizeLoop ()(mod )["main" ].body
193202
194- # Check that the loop was't vectorised
195- assert isinstance (stmt , tvm .tir .For )
203+ with tvm .target .Target (target ):
204+ stmt = tvm .tir .transform .VectorizeLoop ()(mod )["main" ].body
205+
206+ # Check that the loop wasn't vectorised
207+ assert isinstance (stmt , tvm .tir .For )
196208
197209
198- @pytest .mark .parametrize ("extent" , (4 , tvm .tir .vscale () * 4 ) )
199- def test_vectorize_with_ge_cond (extent ):
210+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( tvm .tir .vscale () * 4 , sve_target )] )
211+ def test_vectorize_with_ge_cond (extent , target ):
200212 n = te .var ("n" )
201213 ib = tvm .tir .ir_builder .create ()
202214 A = ib .pointer ("float32" , name = "A" )
@@ -206,14 +218,16 @@ def test_vectorize_with_ge_cond(extent):
206218 stmt = ib .get ()
207219
208220 mod = tvm .IRModule .from_expr (tvm .tir .PrimFunc ([A , n ], stmt ))
209- stmt = tvm .tir .transform .VectorizeLoop ()(mod )["main" ].body
210221
211- # Check that the loop wasn't vectorised
212- assert isinstance ( stmt , tvm .tir .For )
222+ with tvm . target . Target ( target ):
223+ stmt = tvm .tir .transform . VectorizeLoop ()( mod )[ "main" ]. body
213224
225+ # Check that the loop wasn't vectorised
226+ assert isinstance (stmt , tvm .tir .For )
214227
215- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ))
216- def test_vectorize_if_then_else_scalarize (extent ):
228+
229+ @pytest .mark .parametrize ("extent, target" , [(4 , simple_target ), (T .vscale () * 4 , sve_target )])
230+ def test_vectorize_if_then_else_scalarize (extent , target ):
217231 @I .ir_module
218232 class Before :
219233 @T .prim_func
@@ -228,12 +242,13 @@ def main(A: T.Buffer((25,), "float32")):
228242 for i_s in range (extent ):
229243 A [i_s ] = T .if_then_else (i_s > 0 , A [i_s ] + T .float32 (1 ), A [i_s ])
230244
231- mod = tvm .tir .transform .VectorizeLoop ()(Before )
232- tvm .ir .assert_structural_equal (mod , After )
245+ with tvm .target .Target (target ):
246+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
247+ tvm .ir .assert_structural_equal (mod , After )
233248
234249
235- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ) )
236- def test_vectorize_if_then_else_vector (extent ):
250+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( T .vscale () * 4 , sve_target )] )
251+ def test_vectorize_if_then_else_vector (extent , target ):
237252 @I .ir_module
238253 class Before :
239254 @T .prim_func
@@ -251,8 +266,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32):
251266 i > 0 , A [T .Ramp (i * extent , 1 , extent )], T .Broadcast (0 , extent )
252267 )
253268
254- mod = tvm .tir .transform .VectorizeLoop ()(Before )
255- tvm .ir .assert_structural_equal (mod , After )
269+ with tvm .target .Target (target ):
270+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
271+ tvm .ir .assert_structural_equal (mod , After )
256272
257273
258274def test_vectorize_while_fail ():
@@ -311,9 +327,10 @@ def test_vectorize_dtype_mismatch():
311327
312328
313329@pytest .mark .parametrize (
314- "extent, vec_str" , [(16 , "float32x16" ), (T .vscale () * 8 , "float32xvscalex8" )]
330+ "extent, vec_str, target" ,
331+ [(16 , "float32x16" , simple_target ), (T .vscale () * 8 , "float32xvscalex8" , sve_target )],
315332)
316- def test_vectorize_with_reinterpret (extent , vec_str ):
333+ def test_vectorize_with_reinterpret (extent , vec_str , target ):
317334 @I .ir_module
318335 class Before :
319336 @T .prim_func
@@ -327,11 +344,12 @@ class After:
327344 def main (A : T .Buffer ((16 ,), "int32" ), B : T .Buffer ((16 ,), "float32" )):
328345 B [T .Ramp (0 , 1 , extent )] = T .reinterpret (vec_str , A [T .Ramp (0 , 1 , extent )])
329346
330- mod = tvm .tir .transform .VectorizeLoop ()(Before )
331- tvm .ir .assert_structural_equal (mod , After )
347+ with tvm .target .Target (target ):
348+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
349+ tvm .ir .assert_structural_equal (mod , After )
332350
333351
334- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ) )
352+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( T .vscale () * 4 , sve_target )] )
335353@pytest .mark .parametrize (
336354 "op" ,
337355 (
@@ -352,7 +370,7 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
352370 T .NE ,
353371 ),
354372)
355- def test_vectorize_binary (op , extent ):
373+ def test_vectorize_binary (op , extent , target ):
356374 @I .ir_module
357375 class Before :
358376 @T .prim_func
@@ -366,13 +384,14 @@ class After:
366384 def main (A : T .Buffer ((25 ,), "float32" ), B : T .Buffer ((25 ,), "float32" )):
367385 A [T .Ramp (0 , 1 , extent )] = op (T .Broadcast (T .float32 (3 ), extent ), B [T .Ramp (0 , 1 , extent )])
368386
369- mod = tvm .tir .transform .VectorizeLoop ()(Before )
370- tvm .ir .assert_structural_equal (mod , After )
387+ with tvm .target .Target (target ):
388+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
389+ tvm .ir .assert_structural_equal (mod , After )
371390
372391
373- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ) )
392+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( T .vscale () * 4 , sve_target )] )
374393@pytest .mark .parametrize ("op" , (T .And , T .Or ))
375- def test_vectorize_logical (op , extent ):
394+ def test_vectorize_logical (op , extent , target ):
376395 @I .ir_module
377396 class Before :
378397 @T .prim_func
@@ -386,12 +405,13 @@ class After:
386405 def main (A : T .Buffer ((25 ,), "bool" ), B : T .Buffer ((25 ,), "bool" )):
387406 A [T .Ramp (0 , 1 , extent )] = op (T .Broadcast (T .bool (1 ), extent ), B [T .Ramp (0 , 1 , extent )])
388407
389- mod = tvm .tir .transform .VectorizeLoop ()(Before )
390- tvm .ir .assert_structural_equal (mod , After )
408+ with tvm .target .Target (target ):
409+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
410+ tvm .ir .assert_structural_equal (mod , After )
391411
392412
393- @pytest .mark .parametrize ("extent" , (4 , T .vscale () * 4 ) )
394- def test_vectorize_select (extent ):
413+ @pytest .mark .parametrize ("extent, target " , [ (4 , simple_target ), ( T .vscale () * 4 , sve_target )] )
414+ def test_vectorize_select (extent , target ):
395415 @I .ir_module
396416 class Before :
397417 @T .prim_func
@@ -409,12 +429,16 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
409429 B [T .Ramp (0 , 1 , extent )],
410430 )
411431
412- mod = tvm .tir .transform .VectorizeLoop ()(Before )
413- tvm .ir .assert_structural_equal (mod , After )
432+ with tvm .target .Target (target ):
433+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
434+ tvm .ir .assert_structural_equal (mod , After )
414435
415436
416- @pytest .mark .parametrize ("extent, vec_str" , [(4 , "int32x4" ), (T .vscale () * 4 , "int32xvscalex4" )])
417- def test_vectorize_cast (extent , vec_str ):
437+ @pytest .mark .parametrize (
438+ "extent, vec_str, target" ,
439+ [(4 , "int32x4" , simple_target ), (T .vscale () * 4 , "int32xvscalex4" , sve_target )],
440+ )
441+ def test_vectorize_cast (extent , vec_str , target ):
418442 @I .ir_module
419443 class Before :
420444 @T .prim_func
@@ -428,8 +452,9 @@ class After:
428452 def main (A : T .Buffer ((25 ,), "int32" ), B : T .Buffer ((25 ,), "float32" )):
429453 A [T .Ramp (0 , 1 , extent )] = T .Cast (vec_str , B [T .Ramp (0 , 1 , extent )])
430454
431- mod = tvm .tir .transform .VectorizeLoop ()(Before )
432- tvm .ir .assert_structural_equal (mod , After )
455+ with tvm .target .Target (target ):
456+ mod = tvm .tir .transform .VectorizeLoop ()(Before )
457+ tvm .ir .assert_structural_equal (mod , After )
433458
434459
435460def test_illegal_extent ():
@@ -441,10 +466,27 @@ def main(A: T.Buffer((25,), "int32")):
441466 for j in T .vectorized (n ):
442467 A [j ] = 3
443468
444- error_msg = f"Invalid expression for scalable lanes n "
469+ error_msg = f"Failed to vectorize loop with extent n for target \\ (nullptr \\ ) "
445470 with pytest .raises (tvm .error .InternalError , match = error_msg ):
446471 tvm .tir .transform .VectorizeLoop ()(Mod )
447472
448473
474+ def test_illegal_vscale_in_non_sve_compilation ():
475+ @I .ir_module
476+ class Mod :
477+ @T .prim_func
478+ def main (A : T .Buffer ((16 ,), "float32" )):
479+ for j in T .vectorized (0 , 4 * T .vscale ()):
480+ A [j ] = 13
481+
482+ msg = (
483+ f"Failed to vectorize loop with extent T.vscale\\ (\\ ) \\ * 4 for target "
484+ f"llvm -keys=cpu -mtriple=x86_64-linux-gnu"
485+ )
486+ with tvm .target .Target (simple_target ):
487+ with pytest .raises (tvm .error .InternalError , match = msg ):
488+ tvm .tir .transform .VectorizeLoop ()(Mod )
489+
490+
449491if __name__ == "__main__" :
450492 tvm .testing .main ()
0 commit comments