From cd54fcc435aa1a65b5b319f79151267d2351b3dc Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 19 Jun 2025 00:01:30 +0800 Subject: [PATCH] [Script] Add support for merging block annotations This commit introduces functionality to merge block annotations in TVM script. The implementation includes: - MergeAnnotations function that recursively merges annotation dictionaries - Support for nested dictionary merging with new values overriding old ones - Error handling for conflicting annotation values - BlockAttrs function that uses the merging logic to combine multiple T.block_attr() calls within the same block The feature allows users to specify block attributes incrementally using multiple T.block_attr() calls, which will be automatically merged together. --- src/script/ir_builder/tir/ir.cc | 41 +++++++++++++++-- .../tvmscript/test_tvmscript_error_report.py | 17 +++---- .../tvmscript/test_tvmscript_parser_tir.py | 44 +++++++++++++++++++ 3 files changed, 91 insertions(+), 11 deletions(-) diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 6f73254ff2ab..831dbcdd4aa8 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -219,12 +219,47 @@ void Writes(Array buffer_slices) { frame->writes = writes; } +/*! \brief Recursively merge two annotations, the new attrs will override the old ones */ +Map MergeAnnotations(const Map& new_attrs, + const Map& old_attrs) { + Map result = old_attrs; + for (const auto& [key, value] : new_attrs) { + auto old_value = old_attrs.Get(key); + // Case 1: the key is not in the old annotations, set the key to the new value + if (!old_value) { + result.Set(key, value); + continue; + } + + // Case 2: the key is in the old annotations + // Case 2.1: both are dicts + auto old_dict = old_value->try_cast>(); + auto new_dict = value.try_cast>(); + if (old_dict && new_dict) { + // Recursively merge the two dicts + auto merged_dict = MergeAnnotations(*old_dict, *new_dict); + result.Set(key, merged_dict); + continue; + } + // Case 2.2: the values are not both dicts, check if the keys are the same + if (!ffi::AnyEqual()(old_value.value(), value)) { + LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `" + << key << "`, previous one is " << old_value->cast() << ", new one is " + << value.cast(); + } + } + return result; +} + void BlockAttrs(Map attrs) { BlockFrame frame = FindBlockFrame("T.block_attr"); - if (frame->annotations.defined()) { - LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations; + // Case 1: the block has no annotations, set the new annotations + if (!frame->annotations.defined()) { + frame->annotations = attrs; + } else { + // Case 2: the block has annotations, merge the new annotations with the old ones + frame->annotations = MergeAnnotations(attrs, frame->annotations.value()); } - frame->annotations = attrs; } Buffer AllocBuffer(Array shape, DataType dtype, Optional data, diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index d8212d38854c..1cbd6af961c7 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -280,13 +280,6 @@ def duplicate_predicate() -> None: T.where(1) T.where(0) # error - def duplicate_annotations() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({}) - T.block_attr({}) # error - def duplicate_init() -> None: for i, j in T.grid(16, 16): with T.block(): @@ -303,12 +296,20 @@ def duplicate_axes() -> None: vi = T.axis.S(i, 16) # error T.evaluate(1.0) + def duplicate_block_attrs_with_same_key_diff_value() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"key1": "block1"}) + T.block_attr({"key1": "block2"}) # error + T.evaluate(1.0) + check_error(duplicate_reads, 7) check_error(duplicate_writes, 7) check_error(duplicate_predicate, 6) - check_error(duplicate_annotations, 6) check_error(duplicate_init, 7) check_error(duplicate_axes, 5) + check_error(duplicate_block_attrs_with_same_key_diff_value, 6) def test_opaque_access_during_complete(): diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 16b206751402..d5ee2e07729a 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -544,5 +544,49 @@ def expected() -> None: tvm.ir.assert_structural_equal(create_func(False), create_expected(1)) +def test_block_annotation_merge(): + def _to_dict(anno: tvm.ffi.container.Map): + result = {} + for k, v in anno.items(): + result[k] = _to_dict(v) if isinstance(v, tvm.ffi.container.Map) else v + return result + + @T.prim_func + def func0(): + with T.block(): + T.block_attr({"key1": "block1"}) + T.block_attr({"key2": "block2"}) + T.evaluate(0) + + assert _to_dict(func0.body.block.annotations) == {"key1": "block1", "key2": "block2"} + + @T.prim_func + def func1(): + with T.block(): + T.block_attr({"key": {"key1": "block1"}}) + T.block_attr({"key": {"key2": "block2"}}) + T.evaluate(0) + + assert _to_dict(func1.body.block.annotations) == {"key": {"key1": "block1", "key2": "block2"}} + + @T.prim_func + def func2(): + with T.block(): + T.block_attr({"key1": "block1"}) + T.block_attr({"key1": "block1"}) + T.evaluate(0) + + assert _to_dict(func2.body.block.annotations) == {"key1": "block1"} + + with pytest.raises(tvm.TVMError): + + @T.prim_func + def func3(): + with T.block(): + T.block_attr({"key1": "block1"}) + T.block_attr({"key1": "block2"}) + T.evaluate(0) + + if __name__ == "__main__": tvm.testing.main()