Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
const char* name = str->data;
int64_t size = str->size;
// Case 2. string
if (size > 2 && name[0] == '"' && name[size - 1] == '"') {
if (size >= 2 && name[0] == '"' && name[size - 1] == '"') {
results.push_back(String(std::string(name + 1, size - 2)));
continue;
}
Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_tir_schedule_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def _make_enter_postproc():
)


def _make_annotate(block: BlockRV, annotation: str):
return Instruction(
kind=InstructionKind.get("Annotate"),
inputs=[block, annotation],
attrs=["meta_schedule.auto_tensorize"],
outputs=[],
)


def _make_trace_1(b0, l1, l2): # pylint: disable=invalid-name
return Trace(
insts=[
Expand Down Expand Up @@ -275,5 +284,53 @@ def test_apply_json_to_schedule_1():
tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])


def _test_apply_annotation_trace_from_json(annotation: str):
"""Test applying an annotation works without crashing.

Designed to handle some previously failing edge cases like the
empty string.
"""
b0 = BlockRV()
trace = Trace(
insts=[
_make_get_block(name="B", output=b0),
_make_annotate(block=b0, annotation=annotation),
],
decisions={},
)
json_obj = trace.as_json()
sch = tir.Schedule(elementwise, debug_mask="all")
Trace.apply_json_to_schedule(json_obj, sch)

@T.prim_func
def elementwise_expected(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.alloc_buffer((128, 128))
C = T.match_buffer(c, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
T.block_attr({"meta_schedule.auto_tensorize":annotation})
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0

tvm.ir.assert_structural_equal(elementwise_expected, sch.mod["main"])

def test_apply_annotation_from_json():
# Something reasonable
_test_apply_annotation_trace_from_json("SSRSSR")

# The empty string
_test_apply_annotation_trace_from_json("")

# A string of two quotation marks
_test_apply_annotation_trace_from_json('""')

# A string of one quotation mark
_test_apply_annotation_trace_from_json('"')

if __name__ == "__main__":
tvm.testing.main()