Skip to content

Commit f4af81c

Browse files
[TIRScript] fix parse StringImm value in for loop annotations (#9755)
* fix parse strimm value in for annotations * flatten buffer allow runtime.String attr value * remove unused import * rebase and ensure flattened attr order
1 parent 5cb5c5b commit f4af81c

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

python/tvm/script/tir/scope_handler.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import synr
2222
import tvm.tir
23-
from tvm.runtime import Object, String
23+
from tvm.runtime import Object
2424
from tvm.ir import Span, Range
2525
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
2626

@@ -483,14 +483,8 @@ def create_loop_info(
483483
"""
484484
assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
485485
extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
486-
self.annotations: Mapping[str, Object] = {}
487-
if annotations is not None:
488-
self.annotations = {
489-
key: String(val) if isinstance(val, str) else val
490-
for key, val in annotations.items()
491-
}
492-
493-
self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
486+
self.annotations = annotations
487+
self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, self.annotations))
494488

495489

496490
@register

src/tir/transforms/flatten_buffer.cc

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,16 @@ class BufferFlattener : public StmtExprMutator {
9797
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
9898
}
9999
// Step 4. Handle annotations
100+
std::set<std::string> ordered_ann_keys;
100101
for (const auto& annotation : op->annotations) {
101-
const String& ann_key = annotation.first;
102-
const ObjectRef& ann_value = annotation.second;
102+
ordered_ann_keys.insert(annotation.first);
103+
}
104+
for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
105+
const std::string& ann_key = *it;
106+
const ObjectRef& ann_value = op->annotations.at(ann_key);
103107
if (attr::IsPragmaKey(ann_key)) {
104-
body = AttrStmt(op->loop_var, ann_key, Downcast<PrimExpr>(ann_value), std::move(body));
108+
body =
109+
AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
105110
}
106111
}
107112
return body;
@@ -154,6 +159,21 @@ class BufferFlattener : public StmtExprMutator {
154159
/*body=*/std::move(body));
155160
}
156161

162+
/*! \brief Convert attr value from annotation map into PrimExpr. */
163+
PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
164+
if (!obj.defined()) {
165+
return PrimExpr();
166+
} else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) {
167+
return GetRef<PrimExpr>(expr);
168+
} else if (const StringObj* str = obj.as<StringObj>()) {
169+
return std::move(StringImm(str->data));
170+
} else {
171+
LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey()
172+
<< " not supported";
173+
return PrimExpr();
174+
}
175+
}
176+
157177
/*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
158178
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
159179
};

tests/python/unittest/test_tir_transform_flatten_buffer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
247247
C.data[i0 * 64 + i1 * 16 + j] = T.load("float32", B_new, i1 * 17 + j) * 2.0
248248

249249

250+
@T.prim_func
251+
def annotated_loops(a: T.handle) -> None:
252+
A = T.match_buffer(a, (16,), "float32")
253+
for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}):
254+
A[i] = 0.0
255+
256+
250257
def test_elementwise():
251258
_check(compacted_elementwise_func, flattened_elementwise_func)
252259

@@ -284,6 +291,20 @@ def test_lower_te():
284291
tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE
285292

286293

294+
def test_annotated_loops():
295+
mod = tvm.IRModule.from_expr(annotated_loops)
296+
mod = tvm.tir.transform.FlattenBuffer()(mod)
297+
# _check(annotated_loops, compacted_annotated_loops)
298+
attr1 = mod["main"].body
299+
attr2 = attr1.body
300+
attr3 = attr2.body
301+
assert attr1.attr_key == "pragma_1" and attr1.value == "str_value"
302+
assert attr2.attr_key == "pragma_2"
303+
tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1))
304+
assert attr3.attr_key == "pragma_3"
305+
tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))
306+
307+
287308
if __name__ == "__main__":
288309
test_elementwise()
289310
test_gpu_workload()
@@ -293,3 +314,4 @@ def test_lower_te():
293314
test_multi_alloc()
294315
test_strided_buffer()
295316
test_lower_te()
317+
test_annotated_loops()

0 commit comments

Comments
 (0)