Skip to content

Commit a87931a

Browse files
committed
Rename show_inferable_type_annotations to show_all_struct_info
1 parent 47876a9 commit a87931a

File tree

6 files changed

+62
-21
lines changed

6 files changed

+62
-21
lines changed

include/tvm/node/script_printer.h

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,40 @@ class PrinterConfigNode : public Object {
7272
bool syntax_sugar = true;
7373
/*! \brief Whether variable names should include the object's address */
7474
bool show_object_address = false;
75-
/*! \brief Whether to show StructInfo that can be inferred from arguments */
76-
bool show_inferable_type_annotations = true;
75+
76+
/*! \brief In Relax, whether to show all StructInfo annotations
77+
*
78+
* If true (default), all variable bindings will be annotated with
79+
* the struct info of the variable being bound.
80+
*
81+
* If false, the annotations will only be shown when they are
82+
* required for correct parsing of the Relax function. For example,
83+
* function parameters must always have struct info annotations, but
84+
* the struct info for expressions within a function body may be inferred from their
85+
* arguments, and are therefore
86+
*
87+
* Example:
88+
*
89+
* # func.show(show_all_struct_info=True)
90+
* @R.function
91+
* def func(
92+
* A: R.Tensor((10, 20), dtype="float32"),
93+
* B: R.Tensor((10,20), dtype="float32"),
94+
* ) -> R.Tensor((10, 20), dtype="float32"):
95+
* C: R.Tensor((10,20), dtype="float32") = R.add(A, B2)
96+
* return C
97+
*
98+
* # func.show(show_all_struct_info=False)
99+
* @R.function
100+
* def func(
101+
* A: R.Tensor((10, 20), dtype="float32"),
102+
* B: R.Tensor((10,20), dtype="float32"),
103+
* ) -> R.Tensor((10, 20), dtype="float32"):
104+
* C = R.add(A, B2)
105+
* return C
106+
*/
107+
bool show_all_struct_info = true;
108+
77109
/* \brief Object path to be underlined */
78110
Array<ObjectPath> path_to_underline = Array<ObjectPath>();
79111
/*! \brief Object path to be annotated. */
@@ -99,7 +131,7 @@ class PrinterConfigNode : public Object {
99131
v->Visit("num_context_lines", &num_context_lines);
100132
v->Visit("syntax_sugar", &syntax_sugar);
101133
v->Visit("show_object_address", &show_object_address);
102-
v->Visit("show_inferable_type_annotations", &show_inferable_type_annotations);
134+
v->Visit("show_all_struct_info", &show_all_struct_info);
103135
v->Visit("path_to_underline", &path_to_underline);
104136
v->Visit("path_to_annotate", &path_to_annotate);
105137
v->Visit("obj_to_underline", &obj_to_underline);

python/tvm/runtime/script_printer.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PrinterConfig(Object):
4444
num_context_lines: int
4545
syntax_sugar: bool
4646
show_object_address: bool
47-
show_inferable_type_annotations: bool
47+
show_all_struct_info: bool
4848
path_to_underline: Optional[List[ObjectPath]]
4949
path_to_annotate: Optional[Dict[ObjectPath, str]]
5050
obj_to_underline: Optional[List[Object]]
@@ -68,7 +68,7 @@ def __init__(
6868
num_context_lines: Optional[int] = None,
6969
syntax_sugar: bool = True,
7070
show_object_address: bool = False,
71-
show_inferable_type_annotations: bool = True,
71+
show_all_struct_info: bool = True,
7272
path_to_underline: Optional[List[ObjectPath]] = None,
7373
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
7474
obj_to_underline: Optional[List[Object]] = None,
@@ -91,7 +91,7 @@ def __init__(
9191
"num_context_lines": num_context_lines,
9292
"syntax_sugar": syntax_sugar,
9393
"show_object_address": show_object_address,
94-
"show_inferable_type_annotations": show_inferable_type_annotations,
94+
"show_all_struct_info": show_all_struct_info,
9595
"path_to_underline": path_to_underline,
9696
"path_to_annotate": path_to_annotate,
9797
"obj_to_underline": obj_to_underline,
@@ -135,7 +135,7 @@ def script(
135135
num_context_lines: int = -1,
136136
syntax_sugar: bool = True,
137137
show_object_address: bool = False,
138-
show_inferable_type_annotations: bool = True,
138+
show_all_struct_info: bool = True,
139139
path_to_underline: Optional[List[ObjectPath]] = None,
140140
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
141141
obj_to_underline: Optional[List[Object]] = None,
@@ -176,9 +176,10 @@ def script(
176176
Whether to output with syntax sugar, set false for complete printing.
177177
show_object_address: bool = False
178178
Whether to include the object's address as part of the TVMScript name
179-
show_inferable_type_annotations: bool = True
180-
Whether to show type annotations that can be inferred from previous
181-
annotations.
179+
show_all_struct_info: bool = True
180+
If True (default), annotate all variable bindings with the struct
181+
info of that variable. If False, only add annotations where
182+
required for unambiguous round-trip of Relax -> TVMScript -> Relax.
182183
path_to_underline : Optional[List[ObjectPath]] = None
183184
Object path to be underlined
184185
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
@@ -192,6 +193,7 @@ def script(
192193
-------
193194
script : str
194195
The TVM Script of the given TVM IR
196+
195197
"""
196198
return _script(
197199
self,
@@ -211,7 +213,7 @@ def script(
211213
num_context_lines=num_context_lines,
212214
syntax_sugar=syntax_sugar,
213215
show_object_address=show_object_address,
214-
show_inferable_type_annotations=show_inferable_type_annotations,
216+
show_all_struct_info=show_all_struct_info,
215217
path_to_underline=path_to_underline,
216218
path_to_annotate=path_to_annotate,
217219
obj_to_underline=obj_to_underline,
@@ -287,7 +289,7 @@ def show(
287289
num_context_lines: int = -1,
288290
syntax_sugar: bool = True,
289291
show_object_address: bool = False,
290-
show_inferable_type_annotations: bool = True,
292+
show_all_struct_info: bool = True,
291293
path_to_underline: Optional[List[ObjectPath]] = None,
292294
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
293295
obj_to_underline: Optional[List[Object]] = None,
@@ -351,9 +353,10 @@ def show(
351353
Whether to output with syntax sugar, set false for complete printing.
352354
show_object_address: bool = False
353355
Whether to include the object's address as part of the TVMScript name
354-
show_inferable_type_annotations: bool = True
355-
Whether to show type annotations that can be inferred from previous
356-
annotations.
356+
show_all_struct_info: bool = True
357+
If True (default), annotate all variable bindings with the struct
358+
info of that variable. If False, only add annotations where
359+
required for unambiguous round-trip of Relax -> TVMScript -> Relax.
357360
path_to_underline : Optional[List[ObjectPath]] = None
358361
Object path to be underlined
359362
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
@@ -389,7 +392,7 @@ def show(
389392
num_context_lines=num_context_lines,
390393
syntax_sugar=syntax_sugar,
391394
show_object_address=show_object_address,
392-
show_inferable_type_annotations=show_inferable_type_annotations,
395+
show_all_struct_info=show_all_struct_info,
393396
path_to_underline=path_to_underline,
394397
path_to_annotate=path_to_annotate,
395398
obj_to_underline=obj_to_underline,

src/node/script_printer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
112112
if (auto v = config_dict.Get("show_object_address")) {
113113
n->show_object_address = Downcast<IntImm>(v)->value;
114114
}
115-
if (auto v = config_dict.Get("show_inferable_type_annotations")) {
116-
n->show_inferable_type_annotations = Downcast<IntImm>(v)->value;
115+
if (auto v = config_dict.Get("show_all_struct_info")) {
116+
n->show_all_struct_info = Downcast<IntImm>(v)->value;
117117
}
118118

119119
// Checking prefixes if they are valid Python identifiers.

src/script/printer/relax/binding.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
4545
using relax::StructInfo;
4646
using relax::MatchStructInfo;
4747
Optional<ExprDoc> ann = NullOpt;
48-
if (d->cfg->show_inferable_type_annotations) {
48+
if (d->cfg->show_all_struct_info) {
4949
ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
5050
}
5151
ExprDoc rhs = Relax(d, "match_cast")

src/script/printer/relax/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
9191
return NullOpt;
9292
}
9393
}
94-
if (!d->cfg->show_inferable_type_annotations) {
94+
if (!d->cfg->show_all_struct_info) {
9595
Optional<relax::StructInfo> inferred_sinfo = NullOpt;
9696
if (auto opt = rhs.as<relax::Call>()) {
9797
auto call = opt.value();

tests/python/relax/test_tvmscript_printer_relax.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,12 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype
805805

806806

807807
def test_hide_inferable_struct_info():
808+
"""Redundant type annotations can be omitted
809+
810+
When `show_all_struct_info=False`, TVMScript type annotations that
811+
provide redundant struct info can be omitted.
812+
"""
813+
808814
@R.function
809815
def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")):
810816
# R.match_cast has the struct info as an argument, so it can
@@ -833,7 +839,7 @@ def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32"))
833839
return E
834840

835841
_assert_print(
836-
func.script(show_inferable_type_annotations=False),
842+
func.script(show_all_struct_info=False),
837843
"""
838844
# from tvm.script import relax as R
839845

0 commit comments

Comments
 (0)