Skip to content

Commit acbcaf7

Browse files
committed
[Script] Add support for merging block annotations
This commit introduces functionality to merge block annotations in TVM script. The implementation includes: - MergeAnnotations function that recursively merges annotation dictionaries - Support for nested dictionary merging with new values overriding old ones - Error handling for conflicting annotation values - BlockAttrs function that uses the merging logic to combine multiple T.block_attr() calls within the same block The feature allows users to specify block attributes incrementally using multiple T.block_attr() calls, which will be automatically merged together.
1 parent 3c7c515 commit acbcaf7

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

src/script/ir_builder/tir/ir.cc

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,47 @@ void Writes(Array<ObjectRef> buffer_slices) {
219219
frame->writes = writes;
220220
}
221221

222+
/*! \brief Recursively merge two annotations, the new attrs will override the old ones */
223+
Map<String, Any> MergeAnnotations(const Map<String, Any>& new_attrs,
224+
const Map<String, Any>& old_attrs) {
225+
Map<String, Any> result = old_attrs;
226+
for (const auto& [key, value] : new_attrs) {
227+
auto old_value = old_attrs.Get(key);
228+
// Case 1: the key is not in the old annotations, set the key to the new value
229+
if (!old_value) {
230+
result.Set(key, value);
231+
continue;
232+
}
233+
234+
// Case 2: the key is in the old annotations
235+
// Case 2.1: both are dicts
236+
auto old_dict = old_value->try_cast<Map<String, Any>>();
237+
auto new_dict = value.try_cast<Map<String, Any>>();
238+
if (old_dict && new_dict) {
239+
// Recursively merge the two dicts
240+
auto merged_dict = MergeAnnotations(*old_dict, *new_dict);
241+
result.Set(key, merged_dict);
242+
continue;
243+
}
244+
// Case 2.2: the values are not both dicts, check if the keys are the same
245+
if (!ffi::AnyEqual()(old_value.value(), value)) {
246+
LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `"
247+
<< key << "`, previous one is " << old_value->cast<ObjectRef>() << ", new one is "
248+
<< value.cast<ObjectRef>();
249+
}
250+
}
251+
return result;
252+
}
253+
222254
void BlockAttrs(Map<String, Any> attrs) {
223255
BlockFrame frame = FindBlockFrame("T.block_attr");
224-
if (frame->annotations.defined()) {
225-
LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
256+
// Case 1: the block has no annotations, set the new annotations
257+
if (!frame->annotations.defined()) {
258+
frame->annotations = attrs;
259+
} else {
260+
// Case 2: the block has annotations, merge the new annotations with the old ones
261+
frame->annotations = MergeAnnotations(attrs, frame->annotations.value());
226262
}
227-
frame->annotations = attrs;
228263
}
229264

230265
Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,

tests/python/tvmscript/test_tvmscript_parser_tir.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,5 +544,49 @@ def expected() -> None:
544544
tvm.ir.assert_structural_equal(create_func(False), create_expected(1))
545545

546546

547+
def test_block_annotation_merge():
548+
def _to_dict(anno: tvm.ffi.container.Map):
549+
result = {}
550+
for k, v in anno.items():
551+
result[k] = _to_dict(v) if isinstance(v, tvm.ffi.container.Map) else v
552+
return result
553+
554+
@T.prim_func
555+
def func0():
556+
with T.block():
557+
T.block_attr({"key1": "block1"})
558+
T.block_attr({"key2": "block2"})
559+
T.evaluate(0)
560+
561+
assert _to_dict(func0.body.block.annotations) == {"key1": "block1", "key2": "block2"}
562+
563+
@T.prim_func
564+
def func1():
565+
with T.block():
566+
T.block_attr({"key": {"key1": "block1"}})
567+
T.block_attr({"key": {"key2": "block2"}})
568+
T.evaluate(0)
569+
570+
assert _to_dict(func1.body.block.annotations) == {"key": {"key1": "block1", "key2": "block2"}}
571+
572+
@T.prim_func
573+
def func2():
574+
with T.block():
575+
T.block_attr({"key1": "block1"})
576+
T.block_attr({"key1": "block1"})
577+
T.evaluate(0)
578+
579+
assert _to_dict(func2.body.block.annotations) == {"key1": "block1"}
580+
581+
with pytest.raises(tvm.TVMError):
582+
583+
@T.prim_func
584+
def func3():
585+
with T.block():
586+
T.block_attr({"key1": "block1"})
587+
T.block_attr({"key1": "block2"})
588+
T.evaluate(0)
589+
590+
547591
if __name__ == "__main__":
548592
tvm.testing.main()

0 commit comments

Comments
 (0)