2222import tvm .testing
2323from tvm import tir
2424from tvm .script import tir as T
25- from tvm .tir .schedule .testing import verify_trace_roundtrip , assert_structural_equal_gs
25+ from tvm .tir .schedule .testing import (
26+ verify_trace_roundtrip ,
27+ assert_structural_equal_ignore_global_symbol ,
28+ )
2629
2730# pylint: disable=no-member,invalid-name,unused-variable
2831
@@ -1284,7 +1287,7 @@ def test_cache_read_elementwise(use_block_name):
12841287 assert sch .get (cached_b ) == sch .get (sch .get_block ("B_local" ))
12851288 assert sch .get (block_b ) == sch .get (sch .get_block ("B" ))
12861289 assert sch .get (block_c ) == sch .get (sch .get_block ("C" ))
1287- assert_structural_equal_gs (cache_read_elementwise , sch .mod ["main" ])
1290+ assert_structural_equal_ignore_global_symbol (cache_read_elementwise , sch .mod ["main" ])
12881291 verify_trace_roundtrip (sch = sch , mod = elementwise )
12891292
12901293
@@ -1294,39 +1297,39 @@ def test_cache_read_under_scope(use_block_name):
12941297 block_c = "C" if use_block_name else sch .get_block ("C" )
12951298 sch .cache_read (block_b , 0 , "local" )
12961299 sch .cache_read (block_c , 0 , "global" )
1297- assert_structural_equal_gs (cache_read_under_scope , sch .mod ["main" ])
1300+ assert_structural_equal_ignore_global_symbol (cache_read_under_scope , sch .mod ["main" ])
12981301 verify_trace_roundtrip (sch = sch , mod = access_under_scope )
12991302
13001303
13011304def test_cache_read_opaque_access (use_block_name ):
13021305 sch = tir .Schedule (opaque_access , debug_mask = "all" )
13031306 block = "load_store" if use_block_name else sch .get_block ("load_store" )
13041307 sch .cache_read (block , 0 , "global" )
1305- assert_structural_equal_gs (cache_read_opaque_access , sch .mod ["main" ])
1308+ assert_structural_equal_ignore_global_symbol (cache_read_opaque_access , sch .mod ["main" ])
13061309 verify_trace_roundtrip (sch = sch , mod = opaque_access )
13071310
13081311
13091312def test_cache_read_location (use_block_name ):
13101313 sch = tir .Schedule (func_multi_consumer , debug_mask = "all" )
13111314 block_b = "B" if use_block_name else sch .get_block ("B" )
13121315 sch .cache_read (block_b , 0 , "global" )
1313- assert_structural_equal_gs (cache_read_multi_consumer , sch .mod ["main" ])
1316+ assert_structural_equal_ignore_global_symbol (cache_read_multi_consumer , sch .mod ["main" ])
13141317 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
13151318
13161319 # Test that specific consumer block targeting works.
13171320 sch = tir .Schedule (func_multi_consumer , debug_mask = "all" )
13181321 block_b = "B" if use_block_name else sch .get_block ("B" )
13191322 block_c = "C" if use_block_name else sch .get_block ("C" )
13201323 sch .cache_read (block_b , 0 , "global" , consumer_blocks = [block_c ])
1321- assert_structural_equal_gs (cache_read_multi_consumer_target , sch .mod ["main" ])
1324+ assert_structural_equal_ignore_global_symbol (cache_read_multi_consumer_target , sch .mod ["main" ])
13221325 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
13231326
13241327 # Also test setting multiple consumers yields same result as unspecified.
13251328 sch = tir .Schedule (func_multi_consumer , debug_mask = "all" )
13261329 block_b = "B" if use_block_name else sch .get_block ("B" )
13271330 block_c = "C" if use_block_name else sch .get_block ("C" )
13281331 sch .cache_read (block_b , 0 , "global" , consumer_blocks = [block_b , block_c ])
1329- assert_structural_equal_gs (cache_read_multi_consumer , sch .mod ["main" ])
1332+ assert_structural_equal_ignore_global_symbol (cache_read_multi_consumer , sch .mod ["main" ])
13301333 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
13311334
13321335
@@ -1335,23 +1338,23 @@ def test_continuous_cache_read(use_block_name):
13351338 block_c = "C" if use_block_name else sch .get_block ("C" )
13361339 sch .cache_read (block_c , 0 , "shared" )
13371340 sch .cache_read (block_c , 0 , "local" )
1338- assert_structural_equal_gs (continuous_cache_read , sch .mod ["main" ])
1341+ assert_structural_equal_ignore_global_symbol (continuous_cache_read , sch .mod ["main" ])
13391342 verify_trace_roundtrip (sch = sch , mod = elementwise )
13401343
13411344
13421345def test_cache_read_with_block_predicate (use_block_name ):
13431346 sch = tir .Schedule (func_with_block_predicate , debug_mask = "all" )
13441347 block = "consumer" if use_block_name else sch .get_block ("consumer" )
13451348 sch .cache_read (block , 0 , "shared" )
1346- assert_structural_equal_gs (block_predicate_cache_read , sch .mod ["main" ])
1349+ assert_structural_equal_ignore_global_symbol (block_predicate_cache_read , sch .mod ["main" ])
13471350 verify_trace_roundtrip (sch = sch , mod = func_with_block_predicate )
13481351
13491352
13501353def test_cache_read_non_int32_shape (use_block_name ):
13511354 sch = tir .Schedule (elementwise_shape_int64 , debug_mask = "all" )
13521355 block_b = "B" if use_block_name else sch .get_block ("B" )
13531356 sch .cache_read (block_b , 0 , "global" )
1354- assert_structural_equal_gs (cache_read_shape_int64 , sch .mod ["main" ])
1357+ assert_structural_equal_ignore_global_symbol (cache_read_shape_int64 , sch .mod ["main" ])
13551358 verify_trace_roundtrip (sch = sch , mod = elementwise_shape_int64 )
13561359
13571360
@@ -1380,7 +1383,7 @@ def test_inplace_cache_read():
13801383 sch = tvm .tir .Schedule (inplace_func , debug_mask = "all" )
13811384 block = sch .get_block ("copy_in" )
13821385 sch .cache_read (block , 0 , "local" , [block ])
1383- assert_structural_equal_gs (cache_read_inplace , sch .mod ["main" ])
1386+ assert_structural_equal_ignore_global_symbol (cache_read_inplace , sch .mod ["main" ])
13841387 verify_trace_roundtrip (sch = sch , mod = inplace_func )
13851388
13861389
@@ -1393,15 +1396,15 @@ def test_cache_inplace():
13931396 block = sch .cache_read (blocks [0 ], 0 , "global" , [blocks [0 ]])
13941397 block = sch .cache_write (blocks [1 ], 0 , "global" )
13951398
1396- assert_structural_equal_gs (cache_inplace_buffer , sch .mod ["main" ])
1399+ assert_structural_equal_ignore_global_symbol (cache_inplace_buffer , sch .mod ["main" ])
13971400 verify_trace_roundtrip (sch = sch , mod = inplace_call , debug_mask = debug_mask )
13981401
13991402
14001403def test_cache_read_nested_seq (use_block_name ):
14011404 sch = tir .Schedule (func_nested_seq , debug_mask = "all" )
14021405 block_c = "C" if use_block_name else sch .get_block ("C" )
14031406 sch .cache_read (block_c , 0 , "global" , consumer_blocks = [block_c ])
1404- assert_structural_equal_gs (cache_read_nested_seq_target , sch .mod ["main" ])
1407+ assert_structural_equal_ignore_global_symbol (cache_read_nested_seq_target , sch .mod ["main" ])
14051408 verify_trace_roundtrip (sch = sch , mod = func_nested_seq )
14061409
14071410
@@ -1418,7 +1421,7 @@ def test_cache_write_elementwise(use_block_name):
14181421 assert sch .get (cached_c ) == sch .get (sch .get_block ("C_global" ))
14191422 assert sch .get (block_b ) == sch .get (sch .get_block ("B" ))
14201423 assert sch .get (block_c ) == sch .get (sch .get_block ("C" ))
1421- assert_structural_equal_gs (cache_write_elementwise , sch .mod ["main" ])
1424+ assert_structural_equal_ignore_global_symbol (cache_write_elementwise , sch .mod ["main" ])
14221425 verify_trace_roundtrip (sch = sch , mod = elementwise )
14231426
14241427
@@ -1430,7 +1433,7 @@ def test_cache_write_under_scope(use_block_name):
14301433 sch .cache_write (block_a , 0 , "local" )
14311434 sch .cache_write (block_b , 0 , "global" )
14321435 sch .cache_write (block_scope , 0 , "global" )
1433- assert_structural_equal_gs (cache_write_under_scope , sch .mod ["main" ])
1436+ assert_structural_equal_ignore_global_symbol (cache_write_under_scope , sch .mod ["main" ])
14341437 verify_trace_roundtrip (sch = sch , mod = access_under_scope )
14351438
14361439
@@ -1442,15 +1445,15 @@ def test_cache_write_opaque_access(use_block_name):
14421445 sch .cache_write (block_store , 0 , "global" )
14431446 sch .cache_write (block_opaque , 0 , "global" )
14441447 sch .cache_write (block_match_buffer , 0 , "global" )
1445- assert_structural_equal_gs (cache_write_opaque_access , sch .mod ["main" ])
1448+ assert_structural_equal_ignore_global_symbol (cache_write_opaque_access , sch .mod ["main" ])
14461449 verify_trace_roundtrip (sch = sch , mod = opaque_access )
14471450
14481451
14491452def test_cache_write_location (use_block_name ):
14501453 sch = tir .Schedule (func_multi_consumer , debug_mask = "all" )
14511454 block_a = "A" if use_block_name else sch .get_block ("A" )
14521455 sch .cache_write (block_a , 0 , "global" )
1453- assert_structural_equal_gs (cache_write_multi_consumer , sch .mod ["main" ])
1456+ assert_structural_equal_ignore_global_symbol (cache_write_multi_consumer , sch .mod ["main" ])
14541457 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
14551458
14561459 # Test that specific consumer block targeting works.
@@ -1459,7 +1462,9 @@ def test_cache_write_location(use_block_name):
14591462 block_a = "A" if use_block_name else sch .get_block ("A" )
14601463 block_b = "B" if use_block_name else sch .get_block ("B" )
14611464 sch .cache_write (block_a , 0 , "global" , consumer_blocks = [block_b ])
1462- assert_structural_equal_gs (cache_write_multi_consumer_B_consume_cache , sch .mod ["main" ])
1465+ assert_structural_equal_ignore_global_symbol (
1466+ cache_write_multi_consumer_B_consume_cache , sch .mod ["main" ]
1467+ )
14631468 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
14641469
14651470 # Test that specific consumer block targeting works.
@@ -1468,7 +1473,9 @@ def test_cache_write_location(use_block_name):
14681473 block_a = "A" if use_block_name else sch .get_block ("A" )
14691474 block_c = "C" if use_block_name else sch .get_block ("C" )
14701475 sch .cache_write (block_a , 0 , "global" , consumer_blocks = [block_c ])
1471- assert_structural_equal_gs (cache_write_multi_consumer_C_consume_cache , sch .mod ["main" ])
1476+ assert_structural_equal_ignore_global_symbol (
1477+ cache_write_multi_consumer_C_consume_cache , sch .mod ["main" ]
1478+ )
14721479 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
14731480
14741481 # Test that specific consumer block targeting works.
@@ -1478,7 +1485,9 @@ def test_cache_write_location(use_block_name):
14781485 block_b = "B" if use_block_name else sch .get_block ("B" )
14791486 block_c = "C" if use_block_name else sch .get_block ("C" )
14801487 sch .cache_write (block_a , 0 , "global" , consumer_blocks = [block_b , block_c ])
1481- assert_structural_equal_gs (cache_write_multi_consumer_all_consume_cache , sch .mod ["main" ])
1488+ assert_structural_equal_ignore_global_symbol (
1489+ cache_write_multi_consumer_all_consume_cache , sch .mod ["main" ]
1490+ )
14821491 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
14831492
14841493
@@ -1487,7 +1496,7 @@ def test_continuous_cache_write(use_block_name):
14871496 block_b = "B" if use_block_name else sch .get_block ("B" )
14881497 sch .cache_write (block_b , 0 , "shared" )
14891498 sch .cache_write (block_b , 0 , "local" )
1490- assert_structural_equal_gs (continuous_cache_write , sch .mod ["main" ])
1499+ assert_structural_equal_ignore_global_symbol (continuous_cache_write , sch .mod ["main" ])
14911500 verify_trace_roundtrip (sch = sch , mod = elementwise )
14921501
14931502
@@ -1496,13 +1505,17 @@ def test_cache_write_with_block_predicate(use_block_name):
14961505 sch = tir .Schedule (func_with_block_predicate , debug_mask = "all" )
14971506 block = "producer" if use_block_name else sch .get_block ("producer" )
14981507 sch .cache_write (block , 0 , "shared" )
1499- assert_structural_equal_gs (block_predicate_cache_write_intermediate_buf , sch .mod ["main" ])
1508+ assert_structural_equal_ignore_global_symbol (
1509+ block_predicate_cache_write_intermediate_buf , sch .mod ["main" ]
1510+ )
15001511 verify_trace_roundtrip (sch = sch , mod = func_with_block_predicate )
15011512 # cache write for external buffer
15021513 sch = tir .Schedule (func_with_block_predicate , debug_mask = "all" )
15031514 block = "consumer" if use_block_name else sch .get_block ("consumer" )
15041515 sch .cache_write (block , 0 , "shared" )
1505- assert_structural_equal_gs (block_predicate_cache_write_output_buf , sch .mod ["main" ])
1516+ assert_structural_equal_ignore_global_symbol (
1517+ block_predicate_cache_write_output_buf , sch .mod ["main" ]
1518+ )
15061519 verify_trace_roundtrip (sch = sch , mod = func_with_block_predicate )
15071520
15081521
@@ -1601,21 +1614,21 @@ def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float1
16011614
16021615 after = sch .mod ["main" ]
16031616
1604- assert_structural_equal_gs (expected , after )
1617+ assert_structural_equal_ignore_global_symbol (expected , after )
16051618 verify_trace_roundtrip (sch = sch , mod = before )
16061619
16071620
16081621def test_reindex_cache_read ():
16091622 sch = tir .Schedule (elementwise , debug_mask = "all" )
16101623 sch .reindex_cache_read ("C" , 0 , "shared" , lambda i , j : (j , i // 2 , i % 2 ))
1611- assert_structural_equal_gs (elementwise_reindex_cache_read , sch .mod ["main" ])
1624+ assert_structural_equal_ignore_global_symbol (elementwise_reindex_cache_read , sch .mod ["main" ])
16121625 verify_trace_roundtrip (sch = sch , mod = elementwise )
16131626
16141627
16151628def test_reindex_cache_read_multi_consumer ():
16161629 sch = tir .Schedule (func_multi_consumer )
16171630 sch .reindex_cache_read ("B" , 0 , "shared" , lambda i : (i // 32 , i % 32 ))
1618- assert_structural_equal_gs (reindex_cache_read_multi_consumer , sch .mod ["main" ])
1631+ assert_structural_equal_ignore_global_symbol (reindex_cache_read_multi_consumer , sch .mod ["main" ])
16191632 # NOTE(zihao): we do not verify trace roundtrip because of in set analysis issues.
16201633
16211634
@@ -1639,16 +1652,16 @@ def test_reindex_cache_read_failed_not_single_point():
16391652def test_reindex_cache_write ():
16401653 sch = tir .Schedule (elementwise , debug_mask = "all" )
16411654 sch .reindex_cache_write ("B" , 0 , "shared" , lambda i , j : (j , i ))
1642- assert_structural_equal_gs (elementwise_reindex_cache_write , sch .mod ["main" ])
1655+ assert_structural_equal_ignore_global_symbol (elementwise_reindex_cache_write , sch .mod ["main" ])
16431656 verify_trace_roundtrip (sch = sch , mod = elementwise )
16441657
16451658
16461659def test_reindex_cache_write_reduce ():
16471660 sch = tir .Schedule (reduce , debug_mask = "all" )
16481661 sch .reindex_cache_write ("B" , 0 , "shared" , lambda i , j , k , l : (j , i , k ))
1649- assert_structural_equal_gs (reduce_reindex_cache_write_0 , sch .mod ["main" ])
1662+ assert_structural_equal_ignore_global_symbol (reduce_reindex_cache_write_0 , sch .mod ["main" ])
16501663 sch .reindex_cache_write ("C" , 0 , "shared" , lambda i , j , k : [j , i ])
1651- assert_structural_equal_gs (reduce_reindex_cache_write_1 , sch .mod ["main" ])
1664+ assert_structural_equal_ignore_global_symbol (reduce_reindex_cache_write_1 , sch .mod ["main" ])
16521665 verify_trace_roundtrip (sch = sch , mod = reduce )
16531666
16541667
@@ -1673,15 +1686,19 @@ def test_symbolic_matmul_blocked_cache_read(use_block_name):
16731686 sch = tir .Schedule (symbolic_matmul_blocked , debug_mask = "all" )
16741687 block = "matmul" if use_block_name else sch .get_block ("matmul" )
16751688 sch .cache_read (block = block , read_buffer_index = 0 , storage_scope = "shared" )
1676- assert_structural_equal_gs (sch .mod ["main" ], symbolic_matmul_blocked_cache_read )
1689+ assert_structural_equal_ignore_global_symbol (
1690+ sch .mod ["main" ], symbolic_matmul_blocked_cache_read
1691+ )
16771692 verify_trace_roundtrip (sch = sch , mod = symbolic_matmul_blocked )
16781693
16791694
16801695def test_symbolic_matmul_blocked_cache_write (use_block_name ):
16811696 sch = tir .Schedule (symbolic_matmul_blocked , debug_mask = "all" )
16821697 block = "matmul" if use_block_name else sch .get_block ("matmul" )
16831698 sch .cache_write (block = block , write_buffer_index = 0 , storage_scope = "local" )
1684- assert_structural_equal_gs (sch .mod ["main" ], symbolic_matmul_blocked_cache_write )
1699+ assert_structural_equal_ignore_global_symbol (
1700+ sch .mod ["main" ], symbolic_matmul_blocked_cache_write
1701+ )
16851702 verify_trace_roundtrip (sch = sch , mod = symbolic_matmul_blocked )
16861703
16871704
0 commit comments