Skip to content

Commit 292c5a2

Browse files
committed
[TVMScript][Bugfix] Check for StructInfoProxy in R.match_cast
Prior to this commit, bare `StructInfoProxy` annotations could be used to annotate variables (e.g. `var: R.Tensor`). However, they could not be used as the argument of a match cast (e.g. `R.match_cast(obj, R.Tensor)`). This breaks round-trips, as the `R.match_cast` printing generates base `StructInfoProxy` objects. This commit updates TVMScript parsing to handle bare `StructInfoProxy` annotations as an argument to `R.match_cast`.
1 parent 69995d1 commit 292c5a2

File tree

3 files changed

+77
-13
lines changed

3 files changed

+77
-13
lines changed

python/tvm/script/parser/relax/entry.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
############################## R.function ##############################
4949

50+
5051
# this formulation allows us to support having @R.function
5152
# appear as a decorator by itself or to have optional arguments
5253
# like @R.function(pure=False)
@@ -488,8 +489,31 @@ def __init__(self, value: Expr, struct_info: StructInfo) -> None:
488489

489490

490491
def match_cast(value: Expr, struct_info: StructInfo):
492+
struct_info = _normalize_struct_info(struct_info)
493+
491494
if value is None:
492495
raise ValueError("value of match_cast cannot be None")
493496
if struct_info is None:
494497
raise ValueError("struct_info of match_cast cannot be None")
495498
return MatchCastPair(value, struct_info)
499+
500+
501+
def _normalize_struct_info_proxy(annotation) -> StructInfoProxy:
502+
if annotation is None:
503+
return TupleProxy([])
504+
elif callable(annotation):
505+
return annotation()
506+
elif isinstance(annotation, StructInfoProxy):
507+
return annotation
508+
else:
509+
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
510+
511+
512+
def _normalize_struct_info(
513+
struct_info, dict_globals: Optional[Dict[str, Any]] = None
514+
) -> StructInfo:
515+
if isinstance(struct_info, StructInfo):
516+
return struct_info
517+
else:
518+
proxy = _normalize_struct_info_proxy(struct_info)
519+
return proxy.as_struct_info(dict_globals)

python/tvm/script/parser/relax/parser.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@
3030
from ...ir_builder import relax as R
3131
from ...ir_builder.base import IRBuilder
3232
from .._core import Parser, dispatch, doc
33-
from .entry import MatchCastPair, StructInfoProxy, TupleProxy
33+
from .entry import (
34+
MatchCastPair,
35+
StructInfoProxy,
36+
TupleProxy,
37+
_normalize_struct_info_proxy,
38+
_normalize_struct_info,
39+
)
3440

3541

3642
def bind_assign_value(
@@ -91,13 +97,7 @@ def bind_assign_value(
9197
def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
9298
try:
9399
annotation = self.eval_expr(node)
94-
if annotation is None:
95-
return TupleProxy([])
96-
if callable(annotation):
97-
annotation = annotation()
98-
if isinstance(annotation, StructInfoProxy):
99-
return annotation
100-
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
100+
return _normalize_struct_info_proxy(annotation)
101101
except Exception as err:
102102
self.report_error(node, str(err))
103103
raise err
@@ -106,7 +106,8 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
106106
def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo:
107107
var_table = self.var_table.get() if eval_str else None
108108
try:
109-
return eval_struct_info_proxy(self, node).as_struct_info(var_table)
109+
struct_info = self.eval_expr(node)
110+
return _normalize_struct_info(struct_info, var_table)
110111
except Exception as err:
111112
self.report_error(node, str(err))
112113
raise err
@@ -367,7 +368,6 @@ def visit_if(self: Parser, node: doc.If) -> None:
367368
@dispatch.register(token="relax", type_name="enter_token")
368369
def enter_token(self: Parser) -> Dict[str, Any]:
369370
def relax_call(self, *args) -> Expr:
370-
371371
args = [convert_to_expr(arg) if isinstance(arg, tuple) else arg for arg in args]
372372

373373
if all(isinstance(x, Expr) for x in args):

tests/python/tvmscript/test_tvmscript_roundtrip.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
113113
T.ramp((x_c * 32), 1, 32)
114114
] + (
115115
T.broadcast(
116-
A_1[
117-
(((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)),
118-
],
116+
A_1[(((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)),],
119117
32,
120118
)
121119
* packedB[T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32)]
@@ -4023,6 +4021,47 @@ def func(A: R.Tensor([10, 20], "float32")):
40234021
return func
40244022

40254023

4024+
def relax_match_cast_struct_info_proxy():
4025+
"""StructInfoProxy subclasses may be used as expressions
4026+
4027+
This is a regression test. The TVMScript parser allows StructInfo
4028+
to be specified using a default-constructible class
4029+
(e.g. `R.Tensor` or `R.Shape`) rather than an instance of that
4030+
class (e.g. `R.Tensor()` or `R.Shape()`). In previous
4031+
implementations, this was only handled when the `StructInfo` was
4032+
used in an annotation context. However, a `StructInfo` may also
4033+
appear as an argument, which is passed to `R.match_cast`. Use of
4034+
a default-constructible class must be handled in this context as
4035+
well.
4036+
"""
4037+
4038+
def make_ir_generator(proxy_subclass):
4039+
def inner():
4040+
@R.function
4041+
def func(A: R.Object):
4042+
B = R.match_cast(A, proxy_subclass)
4043+
return B
4044+
4045+
return func
4046+
4047+
inner.__name__ = subclass.__name__
4048+
return inner
4049+
4050+
# Not all subclasses of StructInfoProxy are default-constructible.
4051+
# This list is a subset of `StructInfoProxy.__subclasses__()`,
4052+
# excluding `PrimProxy` and `DTensorProxy`.
4053+
subclasses = [
4054+
tvm.script.parser.relax.entry.ObjectProxy,
4055+
tvm.script.parser.relax.entry.TensorProxy,
4056+
tvm.script.parser.relax.entry.CallableProxy,
4057+
tvm.script.parser.relax.entry.TupleProxy,
4058+
tvm.script.parser.relax.entry.ShapeProxy,
4059+
]
4060+
4061+
for subclass in subclasses:
4062+
yield make_ir_generator(subclass)
4063+
4064+
40264065
ir_generator = tvm.testing.parameter(
40274066
launch_env_thread,
40284067
opt_gemm_normalize,
@@ -4106,6 +4145,7 @@ def func(A: R.Tensor([10, 20], "float32")):
41064145
return_zero_private,
41074146
return_zero_private_with_attr,
41084147
*op_of_literal(),
4148+
*relax_match_cast_struct_info_proxy(),
41094149
)
41104150

41114151
relax_ir_generator = tvm.testing.parameter(

0 commit comments

Comments
 (0)