Skip to content

Commit c0df6f7

Browse files
committed
fix bugs
1 parent 874d41b commit c0df6f7

File tree

3 files changed

+5
-11
lines changed

3 files changed

+5
-11
lines changed

src/script/printer/tir/buffer.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,9 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
129129
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
130130
const IRDocsifier& d) {
131131
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
132-
Array<Doc> indices_doc;
133-
for (String s : {"shape", "dtype"}) {
134-
if (Optional<ExprDoc> doc = attrs.Get(s)) {
135-
indices_doc.push_back(doc.value());
136-
}
137-
}
138-
return TIR("Buffer")[indices_doc];
132+
ExprDoc shape = attrs.Get("shape").value();
133+
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
134+
return TIR("Buffer")->Call({shape, dtype}, {}, {});
139135
}
140136

141137
Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,

tests/python/unittest/test_tvmscript_printer_syntax_sugar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def func(a: T.handle, b: T.handle, c: T.handle):
7373
C = T.match_buffer(c, [128, 128, 128], dtype="uint8")
7474

7575
expected_output = """@T.prim_func
76-
def main(A: T.Buffer[(128,)], B: T.Buffer[(128, 128), "int32"], C: T.Buffer[(128, 128, 128), "uint8"]):
76+
def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128, 128), "int32"), C: T.Buffer((128, 128, 128), "uint8")):
7777
T.evaluate(0)"""
7878
_test(func, expected_output)
7979

tests/python/unittest/test_tvmscript_printer_tir.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@ def test_prim_func():
5757
func,
5858
expected="""
5959
@T.prim_func
60-
def main(a: T.handle, b: T.handle):
61-
A = T.match_buffer(a, (128, 128))
62-
B = T.match_buffer(b, (256, 256))
60+
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
6361
T.evaluate(0)""",
6462
)
6563

0 commit comments

Comments
 (0)