Skip to content

Commit 165b124

Browse files
authored
[Bugfix][Op] Register attributes for unique and print (apache#248)
Attempting to use `dump_ast` on functions containing the operators `relax.unique` and `relax.print` previously crashed due to being unable to query their attributes' keys. It turned out that this was a problem with the operator attributes: They had not been registered on the Python side, so Python representation treated them as opaque TVM objects. This PR corrects this mistake.
1 parent 3add7b1 commit 165b124

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

python/tvm/relax/op/op_attrs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,13 @@ class VMAllocStorageAttrs(Attrs):
3232
@tvm._ffi.register_object("relax.attrs.VMAllocTensorAttrs")
3333
class VMAllocTensorAttrs(Attrs):
3434
"""Attributes used in VM alloc_tensor operators"""
35+
36+
37+
@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
38+
class UniqueAttrs(Attrs):
39+
"""Attributes used for the unique operator"""
40+
41+
42+
@tvm._ffi.register_object("relax.attrs.PrintAttrs")
43+
class PrintAttrs(Attrs):
44+
"""Attributes used for the print operator"""

tests/python/relax/test_ast_printer.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,5 +336,40 @@ def foo(x: Tensor((m, n), "float32")):
336336
assert call_tir_text in foo_str
337337

338338

339+
def test_operators():
340+
# the operator attributes need to be registered to work in the printer
341+
342+
@R.function
343+
def foo(x: Tensor):
344+
return relax.unique(x, sorted=True)
345+
346+
foo_str = strip_whitespace(
347+
dump_ast(
348+
foo,
349+
include_type_annotations=False,
350+
include_shape_annotations=False,
351+
)
352+
)
353+
# checking that the attributes are present
354+
assert '"sorted":1' in foo_str
355+
assert '"return_inverse"' in foo_str
356+
assert '"return_counts"' in foo_str
357+
assert '"dim"' in foo_str
358+
359+
@R.function
360+
def bar(x: Tensor):
361+
return relax.print(x, format="{}")
362+
363+
bar_str = strip_whitespace(
364+
dump_ast(
365+
bar,
366+
include_type_annotations=False,
367+
include_shape_annotations=False,
368+
)
369+
)
370+
print_attrs_str = strip_whitespace('{"format": "{}"}')
371+
assert print_attrs_str in bar_str
372+
373+
339374
if __name__ == "__main__":
340375
pytest.main([__file__])

0 commit comments

Comments
 (0)