Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,47 @@ void Writes(Array<ObjectRef> buffer_slices) {
frame->writes = writes;
}

/*! \brief Recursively merge two annotations, the new attrs will override the old ones */
Map<String, Any> MergeAnnotations(const Map<String, Any>& new_attrs,
const Map<String, Any>& old_attrs) {
Map<String, Any> result = old_attrs;
for (const auto& [key, value] : new_attrs) {
auto old_value = old_attrs.Get(key);
// Case 1: the key is not in the old annotations, set the key to the new value
if (!old_value) {
result.Set(key, value);
continue;
}

// Case 2: the key is in the old annotations
// Case 2.1: both are dicts
auto old_dict = old_value->try_cast<Map<String, Any>>();
auto new_dict = value.try_cast<Map<String, Any>>();
if (old_dict && new_dict) {
// Recursively merge the two dicts
auto merged_dict = MergeAnnotations(*old_dict, *new_dict);
result.Set(key, merged_dict);
continue;
}
// Case 2.2: the values are not both dicts, check if the keys are the same
if (!ffi::AnyEqual()(old_value.value(), value)) {
LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be useful to cross check error reporting here, to make sure when error happens, we can locate the span and have right ^^^^^ to point at the code location. @Hzfengsy can you confirm?

Copy link
Member Author

@Hzfengsy Hzfengsy Jun 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed it works well. Try with testcase:

from tvm.script import tir as T

@T.prim_func
def func3():
    with T.block():
        T.block_attr({"key1": "block1"})
        T.block_attr({"key1": "block2"})
        T.evaluate(0)

get output:

error: Try to merge two annotations with different values for key `key1`, previous one is "block1", new one is "block2"
 --> /path/to/tvm/t.py:7:9
   |  
 7 |          T.block_attr({"key1": "block2"})
   |          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I added a test case to test the error in tests/python/tvmscript/test_tvmscript_error_report.py

<< key << "`, previous one is " << old_value->cast<ObjectRef>() << ", new one is "
<< value.cast<ObjectRef>();
}
}
return result;
}

void BlockAttrs(Map<String, Any> attrs) {
BlockFrame frame = FindBlockFrame("T.block_attr");
if (frame->annotations.defined()) {
LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
// Case 1: the block has no annotations, set the new annotations
if (!frame->annotations.defined()) {
frame->annotations = attrs;
} else {
// Case 2: the block has annotations, merge the new annotations with the old ones
frame->annotations = MergeAnnotations(attrs, frame->annotations.value());
}
frame->annotations = attrs;
}

Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
Expand Down
17 changes: 9 additions & 8 deletions tests/python/tvmscript/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,6 @@ def duplicate_predicate() -> None:
T.where(1)
T.where(0) # error

def duplicate_annotations() -> None:
for i, j in T.grid(16, 16):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({})
T.block_attr({}) # error

def duplicate_init() -> None:
for i, j in T.grid(16, 16):
with T.block():
Expand All @@ -303,12 +296,20 @@ def duplicate_axes() -> None:
vi = T.axis.S(i, 16) # error
T.evaluate(1.0)

def duplicate_block_attrs_with_same_key_diff_value() -> None:
for i, j in T.grid(16, 16):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({"key1": "block1"})
T.block_attr({"key1": "block2"}) # error
T.evaluate(1.0)

check_error(duplicate_reads, 7)
check_error(duplicate_writes, 7)
check_error(duplicate_predicate, 6)
check_error(duplicate_annotations, 6)
check_error(duplicate_init, 7)
check_error(duplicate_axes, 5)
check_error(duplicate_block_attrs_with_same_key_diff_value, 6)


def test_opaque_access_during_complete():
Expand Down
44 changes: 44 additions & 0 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,5 +544,49 @@ def expected() -> None:
tvm.ir.assert_structural_equal(create_func(False), create_expected(1))


def test_block_annotation_merge():
def _to_dict(anno: tvm.ffi.container.Map):
result = {}
for k, v in anno.items():
result[k] = _to_dict(v) if isinstance(v, tvm.ffi.container.Map) else v
return result

@T.prim_func
def func0():
with T.block():
T.block_attr({"key1": "block1"})
T.block_attr({"key2": "block2"})
T.evaluate(0)

assert _to_dict(func0.body.block.annotations) == {"key1": "block1", "key2": "block2"}

@T.prim_func
def func1():
with T.block():
T.block_attr({"key": {"key1": "block1"}})
T.block_attr({"key": {"key2": "block2"}})
T.evaluate(0)

assert _to_dict(func1.body.block.annotations) == {"key": {"key1": "block1", "key2": "block2"}}

@T.prim_func
def func2():
with T.block():
T.block_attr({"key1": "block1"})
T.block_attr({"key1": "block1"})
T.evaluate(0)

assert _to_dict(func2.body.block.annotations) == {"key1": "block1"}

with pytest.raises(tvm.TVMError):

@T.prim_func
def func3():
with T.block():
T.block_attr({"key1": "block1"})
T.block_attr({"key1": "block2"})
T.evaluate(0)


if __name__ == "__main__":
tvm.testing.main()
Loading