Skip to content

Commit 2f70c39

Browse files
committed
Do not use abbreviation in assert_structural_equal_gs (expand name)
1 parent 64c0cf0 commit 2f70c39

29 files changed

+395
-246
lines changed

python/tvm/tir/schedule/testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tvm.tir.schedule import Schedule, Trace
2525

2626

27-
def assert_structural_equal_gs(
27+
def assert_structural_equal_ignore_global_symbol(
2828
func1: PrimFunc,
2929
func2: PrimFunc,
3030
*args: Any,

tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tvm import meta_schedule as ms
2121
from tvm.script import tir as T
2222
from tvm.target import Target
23-
from tvm.tir.schedule.testing import assert_structural_equal_gs
23+
from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol
2424

2525

2626
def _target() -> Target:
@@ -202,7 +202,7 @@ def test_layout_rewrite():
202202
sch = tvm.tir.Schedule(tir_matmul, debug_mask="all")
203203
sch.enter_postproc()
204204
assert ctx.space_generator.postprocs[0].apply(sch)
205-
assert_structural_equal_gs(sch.mod["main"], rewritten_tir_matmul)
205+
assert_structural_equal_ignore_global_symbol(sch.mod["main"], rewritten_tir_matmul)
206206

207207

208208
# fmt: off

tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll
2121
from tvm.script import tir as T
2222
from tvm.tir.schedule import Schedule
23-
from tvm.tir.schedule.testing import assert_structural_equal_gs
23+
from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol
2424

2525
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
2626
# fmt: off
@@ -196,14 +196,14 @@ def test_vectorize_inner_loop():
196196
sch = Schedule(before_matmul_vectorize)
197197
rule = RewriteParallelVectorizeUnroll()
198198
assert rule.apply(sch)
199-
assert_structural_equal_gs(sch.mod["main"], after_matmul_vectorize)
199+
assert_structural_equal_ignore_global_symbol(sch.mod["main"], after_matmul_vectorize)
200200

201201

202202
def test_parallel_vectorize_add():
203203
sch = Schedule(before_postproc_add)
204204
rule = RewriteParallelVectorizeUnroll()
205205
assert rule.apply(sch)
206-
assert_structural_equal_gs(sch.mod["main"], after_postproc_add)
206+
assert_structural_equal_ignore_global_symbol(sch.mod["main"], after_postproc_add)
207207

208208

209209
def test_no_unroll_for_spatial_block():
@@ -265,7 +265,7 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo
265265
sch = Schedule(layer_norm)
266266
assert postproc.apply(sch)
267267
mod = tvm.tir.transform.Simplify()(sch.mod)
268-
assert_structural_equal_gs(mod["main"], expected)
268+
assert_structural_equal_ignore_global_symbol(mod["main"], expected)
269269

270270

271271
if __name__ == "__main__":

tests/python/unittest/test_tir_schedule_cache_read_write.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import tvm.testing
2323
from tvm import tir
2424
from tvm.script import tir as T
25-
from tvm.tir.schedule.testing import verify_trace_roundtrip, assert_structural_equal_gs
25+
from tvm.tir.schedule.testing import (
26+
verify_trace_roundtrip,
27+
assert_structural_equal_ignore_global_symbol,
28+
)
2629

2730
# pylint: disable=no-member,invalid-name,unused-variable
2831

@@ -1284,7 +1287,7 @@ def test_cache_read_elementwise(use_block_name):
12841287
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
12851288
assert sch.get(block_b) == sch.get(sch.get_block("B"))
12861289
assert sch.get(block_c) == sch.get(sch.get_block("C"))
1287-
assert_structural_equal_gs(cache_read_elementwise, sch.mod["main"])
1290+
assert_structural_equal_ignore_global_symbol(cache_read_elementwise, sch.mod["main"])
12881291
verify_trace_roundtrip(sch=sch, mod=elementwise)
12891292

12901293

@@ -1294,39 +1297,39 @@ def test_cache_read_under_scope(use_block_name):
12941297
block_c = "C" if use_block_name else sch.get_block("C")
12951298
sch.cache_read(block_b, 0, "local")
12961299
sch.cache_read(block_c, 0, "global")
1297-
assert_structural_equal_gs(cache_read_under_scope, sch.mod["main"])
1300+
assert_structural_equal_ignore_global_symbol(cache_read_under_scope, sch.mod["main"])
12981301
verify_trace_roundtrip(sch=sch, mod=access_under_scope)
12991302

13001303

13011304
def test_cache_read_opaque_access(use_block_name):
13021305
sch = tir.Schedule(opaque_access, debug_mask="all")
13031306
block = "load_store" if use_block_name else sch.get_block("load_store")
13041307
sch.cache_read(block, 0, "global")
1305-
assert_structural_equal_gs(cache_read_opaque_access, sch.mod["main"])
1308+
assert_structural_equal_ignore_global_symbol(cache_read_opaque_access, sch.mod["main"])
13061309
verify_trace_roundtrip(sch=sch, mod=opaque_access)
13071310

13081311

13091312
def test_cache_read_location(use_block_name):
13101313
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
13111314
block_b = "B" if use_block_name else sch.get_block("B")
13121315
sch.cache_read(block_b, 0, "global")
1313-
assert_structural_equal_gs(cache_read_multi_consumer, sch.mod["main"])
1316+
assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"])
13141317
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
13151318

13161319
# Test that specific consumer block targeting works.
13171320
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
13181321
block_b = "B" if use_block_name else sch.get_block("B")
13191322
block_c = "C" if use_block_name else sch.get_block("C")
13201323
sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c])
1321-
assert_structural_equal_gs(cache_read_multi_consumer_target, sch.mod["main"])
1324+
assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer_target, sch.mod["main"])
13221325
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
13231326

13241327
# Also test setting multiple consumers yields same result as unspecified.
13251328
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
13261329
block_b = "B" if use_block_name else sch.get_block("B")
13271330
block_c = "C" if use_block_name else sch.get_block("C")
13281331
sch.cache_read(block_b, 0, "global", consumer_blocks=[block_b, block_c])
1329-
assert_structural_equal_gs(cache_read_multi_consumer, sch.mod["main"])
1332+
assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"])
13301333
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
13311334

13321335

@@ -1335,23 +1338,23 @@ def test_continuous_cache_read(use_block_name):
13351338
block_c = "C" if use_block_name else sch.get_block("C")
13361339
sch.cache_read(block_c, 0, "shared")
13371340
sch.cache_read(block_c, 0, "local")
1338-
assert_structural_equal_gs(continuous_cache_read, sch.mod["main"])
1341+
assert_structural_equal_ignore_global_symbol(continuous_cache_read, sch.mod["main"])
13391342
verify_trace_roundtrip(sch=sch, mod=elementwise)
13401343

13411344

13421345
def test_cache_read_with_block_predicate(use_block_name):
13431346
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
13441347
block = "consumer" if use_block_name else sch.get_block("consumer")
13451348
sch.cache_read(block, 0, "shared")
1346-
assert_structural_equal_gs(block_predicate_cache_read, sch.mod["main"])
1349+
assert_structural_equal_ignore_global_symbol(block_predicate_cache_read, sch.mod["main"])
13471350
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
13481351

13491352

13501353
def test_cache_read_non_int32_shape(use_block_name):
13511354
sch = tir.Schedule(elementwise_shape_int64, debug_mask="all")
13521355
block_b = "B" if use_block_name else sch.get_block("B")
13531356
sch.cache_read(block_b, 0, "global")
1354-
assert_structural_equal_gs(cache_read_shape_int64, sch.mod["main"])
1357+
assert_structural_equal_ignore_global_symbol(cache_read_shape_int64, sch.mod["main"])
13551358
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
13561359

13571360

@@ -1380,7 +1383,7 @@ def test_inplace_cache_read():
13801383
sch = tvm.tir.Schedule(inplace_func, debug_mask="all")
13811384
block = sch.get_block("copy_in")
13821385
sch.cache_read(block, 0, "local", [block])
1383-
assert_structural_equal_gs(cache_read_inplace, sch.mod["main"])
1386+
assert_structural_equal_ignore_global_symbol(cache_read_inplace, sch.mod["main"])
13841387
verify_trace_roundtrip(sch=sch, mod=inplace_func)
13851388

13861389

@@ -1393,15 +1396,15 @@ def test_cache_inplace():
13931396
block = sch.cache_read(blocks[0], 0, "global", [blocks[0]])
13941397
block = sch.cache_write(blocks[1], 0, "global")
13951398

1396-
assert_structural_equal_gs(cache_inplace_buffer, sch.mod["main"])
1399+
assert_structural_equal_ignore_global_symbol(cache_inplace_buffer, sch.mod["main"])
13971400
verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask)
13981401

13991402

14001403
def test_cache_read_nested_seq(use_block_name):
14011404
sch = tir.Schedule(func_nested_seq, debug_mask="all")
14021405
block_c = "C" if use_block_name else sch.get_block("C")
14031406
sch.cache_read(block_c, 0, "global", consumer_blocks=[block_c])
1404-
assert_structural_equal_gs(cache_read_nested_seq_target, sch.mod["main"])
1407+
assert_structural_equal_ignore_global_symbol(cache_read_nested_seq_target, sch.mod["main"])
14051408
verify_trace_roundtrip(sch=sch, mod=func_nested_seq)
14061409

14071410

@@ -1418,7 +1421,7 @@ def test_cache_write_elementwise(use_block_name):
14181421
assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))
14191422
assert sch.get(block_b) == sch.get(sch.get_block("B"))
14201423
assert sch.get(block_c) == sch.get(sch.get_block("C"))
1421-
assert_structural_equal_gs(cache_write_elementwise, sch.mod["main"])
1424+
assert_structural_equal_ignore_global_symbol(cache_write_elementwise, sch.mod["main"])
14221425
verify_trace_roundtrip(sch=sch, mod=elementwise)
14231426

14241427

@@ -1430,7 +1433,7 @@ def test_cache_write_under_scope(use_block_name):
14301433
sch.cache_write(block_a, 0, "local")
14311434
sch.cache_write(block_b, 0, "global")
14321435
sch.cache_write(block_scope, 0, "global")
1433-
assert_structural_equal_gs(cache_write_under_scope, sch.mod["main"])
1436+
assert_structural_equal_ignore_global_symbol(cache_write_under_scope, sch.mod["main"])
14341437
verify_trace_roundtrip(sch=sch, mod=access_under_scope)
14351438

14361439

@@ -1442,15 +1445,15 @@ def test_cache_write_opaque_access(use_block_name):
14421445
sch.cache_write(block_store, 0, "global")
14431446
sch.cache_write(block_opaque, 0, "global")
14441447
sch.cache_write(block_match_buffer, 0, "global")
1445-
assert_structural_equal_gs(cache_write_opaque_access, sch.mod["main"])
1448+
assert_structural_equal_ignore_global_symbol(cache_write_opaque_access, sch.mod["main"])
14461449
verify_trace_roundtrip(sch=sch, mod=opaque_access)
14471450

14481451

14491452
def test_cache_write_location(use_block_name):
14501453
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
14511454
block_a = "A" if use_block_name else sch.get_block("A")
14521455
sch.cache_write(block_a, 0, "global")
1453-
assert_structural_equal_gs(cache_write_multi_consumer, sch.mod["main"])
1456+
assert_structural_equal_ignore_global_symbol(cache_write_multi_consumer, sch.mod["main"])
14541457
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
14551458

14561459
# Test that specific consumer block targeting works.
@@ -1459,7 +1462,9 @@ def test_cache_write_location(use_block_name):
14591462
block_a = "A" if use_block_name else sch.get_block("A")
14601463
block_b = "B" if use_block_name else sch.get_block("B")
14611464
sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b])
1462-
assert_structural_equal_gs(cache_write_multi_consumer_B_consume_cache, sch.mod["main"])
1465+
assert_structural_equal_ignore_global_symbol(
1466+
cache_write_multi_consumer_B_consume_cache, sch.mod["main"]
1467+
)
14631468
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
14641469

14651470
# Test that specific consumer block targeting works.
@@ -1468,7 +1473,9 @@ def test_cache_write_location(use_block_name):
14681473
block_a = "A" if use_block_name else sch.get_block("A")
14691474
block_c = "C" if use_block_name else sch.get_block("C")
14701475
sch.cache_write(block_a, 0, "global", consumer_blocks=[block_c])
1471-
assert_structural_equal_gs(cache_write_multi_consumer_C_consume_cache, sch.mod["main"])
1476+
assert_structural_equal_ignore_global_symbol(
1477+
cache_write_multi_consumer_C_consume_cache, sch.mod["main"]
1478+
)
14721479
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
14731480

14741481
# Test that specific consumer block targeting works.
@@ -1478,7 +1485,9 @@ def test_cache_write_location(use_block_name):
14781485
block_b = "B" if use_block_name else sch.get_block("B")
14791486
block_c = "C" if use_block_name else sch.get_block("C")
14801487
sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b, block_c])
1481-
assert_structural_equal_gs(cache_write_multi_consumer_all_consume_cache, sch.mod["main"])
1488+
assert_structural_equal_ignore_global_symbol(
1489+
cache_write_multi_consumer_all_consume_cache, sch.mod["main"]
1490+
)
14821491
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
14831492

14841493

@@ -1487,7 +1496,7 @@ def test_continuous_cache_write(use_block_name):
14871496
block_b = "B" if use_block_name else sch.get_block("B")
14881497
sch.cache_write(block_b, 0, "shared")
14891498
sch.cache_write(block_b, 0, "local")
1490-
assert_structural_equal_gs(continuous_cache_write, sch.mod["main"])
1499+
assert_structural_equal_ignore_global_symbol(continuous_cache_write, sch.mod["main"])
14911500
verify_trace_roundtrip(sch=sch, mod=elementwise)
14921501

14931502

@@ -1496,13 +1505,17 @@ def test_cache_write_with_block_predicate(use_block_name):
14961505
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
14971506
block = "producer" if use_block_name else sch.get_block("producer")
14981507
sch.cache_write(block, 0, "shared")
1499-
assert_structural_equal_gs(block_predicate_cache_write_intermediate_buf, sch.mod["main"])
1508+
assert_structural_equal_ignore_global_symbol(
1509+
block_predicate_cache_write_intermediate_buf, sch.mod["main"]
1510+
)
15001511
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
15011512
# cache write for external buffer
15021513
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
15031514
block = "consumer" if use_block_name else sch.get_block("consumer")
15041515
sch.cache_write(block, 0, "shared")
1505-
assert_structural_equal_gs(block_predicate_cache_write_output_buf, sch.mod["main"])
1516+
assert_structural_equal_ignore_global_symbol(
1517+
block_predicate_cache_write_output_buf, sch.mod["main"]
1518+
)
15061519
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
15071520

15081521

@@ -1601,21 +1614,21 @@ def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float1
16011614

16021615
after = sch.mod["main"]
16031616

1604-
assert_structural_equal_gs(expected, after)
1617+
assert_structural_equal_ignore_global_symbol(expected, after)
16051618
verify_trace_roundtrip(sch=sch, mod=before)
16061619

16071620

16081621
def test_reindex_cache_read():
16091622
sch = tir.Schedule(elementwise, debug_mask="all")
16101623
sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2))
1611-
assert_structural_equal_gs(elementwise_reindex_cache_read, sch.mod["main"])
1624+
assert_structural_equal_ignore_global_symbol(elementwise_reindex_cache_read, sch.mod["main"])
16121625
verify_trace_roundtrip(sch=sch, mod=elementwise)
16131626

16141627

16151628
def test_reindex_cache_read_multi_consumer():
16161629
sch = tir.Schedule(func_multi_consumer)
16171630
sch.reindex_cache_read("B", 0, "shared", lambda i: (i // 32, i % 32))
1618-
assert_structural_equal_gs(reindex_cache_read_multi_consumer, sch.mod["main"])
1631+
assert_structural_equal_ignore_global_symbol(reindex_cache_read_multi_consumer, sch.mod["main"])
16191632
# NOTE(zihao): we do not verify trace roundtrip because of in set analysis issues.
16201633

16211634

@@ -1639,16 +1652,16 @@ def test_reindex_cache_read_failed_not_single_point():
16391652
def test_reindex_cache_write():
16401653
sch = tir.Schedule(elementwise, debug_mask="all")
16411654
sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i))
1642-
assert_structural_equal_gs(elementwise_reindex_cache_write, sch.mod["main"])
1655+
assert_structural_equal_ignore_global_symbol(elementwise_reindex_cache_write, sch.mod["main"])
16431656
verify_trace_roundtrip(sch=sch, mod=elementwise)
16441657

16451658

16461659
def test_reindex_cache_write_reduce():
16471660
sch = tir.Schedule(reduce, debug_mask="all")
16481661
sch.reindex_cache_write("B", 0, "shared", lambda i, j, k, l: (j, i, k))
1649-
assert_structural_equal_gs(reduce_reindex_cache_write_0, sch.mod["main"])
1662+
assert_structural_equal_ignore_global_symbol(reduce_reindex_cache_write_0, sch.mod["main"])
16501663
sch.reindex_cache_write("C", 0, "shared", lambda i, j, k: [j, i])
1651-
assert_structural_equal_gs(reduce_reindex_cache_write_1, sch.mod["main"])
1664+
assert_structural_equal_ignore_global_symbol(reduce_reindex_cache_write_1, sch.mod["main"])
16521665
verify_trace_roundtrip(sch=sch, mod=reduce)
16531666

16541667

@@ -1673,15 +1686,19 @@ def test_symbolic_matmul_blocked_cache_read(use_block_name):
16731686
sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
16741687
block = "matmul" if use_block_name else sch.get_block("matmul")
16751688
sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared")
1676-
assert_structural_equal_gs(sch.mod["main"], symbolic_matmul_blocked_cache_read)
1689+
assert_structural_equal_ignore_global_symbol(
1690+
sch.mod["main"], symbolic_matmul_blocked_cache_read
1691+
)
16771692
verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
16781693

16791694

16801695
def test_symbolic_matmul_blocked_cache_write(use_block_name):
16811696
sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
16821697
block = "matmul" if use_block_name else sch.get_block("matmul")
16831698
sch.cache_write(block=block, write_buffer_index=0, storage_scope="local")
1684-
assert_structural_equal_gs(sch.mod["main"], symbolic_matmul_blocked_cache_write)
1699+
assert_structural_equal_ignore_global_symbol(
1700+
sch.mod["main"], symbolic_matmul_blocked_cache_write
1701+
)
16851702
verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
16861703

16871704

0 commit comments

Comments
 (0)