Commit 8faf450

Siyuan Feng and MasterJH5574 authored
[TVMScript] B6/B7: Symbolic shape and var shadowing (apache#245)
This PR features symbolic shape support and var shadowing in relax.

Co-authored-by: Ruihang Lai <[email protected]>
1 parent 165b124 commit 8faf450
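
At a glance, the two features look like this in TVMScript (a minimal sketch; the operator choices are illustrative and not part of this commit):

    from tvm.script.parser import ir as I, tir as T, relax as R

    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32")):
        # "m" and "n" in the annotation are parsed into int64 TIR vars (symbolic shape)
        y = R.add(x, x)
        y = R.multiply(y, y)  # rebinding y is now legal (var shadowing)
        return y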

File tree

4 files changed: +118 -11 lines changed

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 9 additions & 2 deletions
@@ -28,6 +28,7 @@
 
 from . import _ffi_api
 from . import frame
+from ..tir import var as _tir_var
 
 ############################### Operators ###############################
 from tvm.relax.op import shape_of, make_closure, invoke_closure
@@ -44,15 +45,15 @@ class TensorType(Object):
 
 
 def tensor(
-    shape: Optional[List[PrimExpr]] = None,
+    shape: Optional[List[Union[PrimExpr, str]]] = None,
     dtype: Optional[str] = None,
     ndim: int = -1,
 ):
     """Helper function for `R.Tensor` in parser
 
     Parameters
     ----------
-    shape: Optional[List[PrimExpr]]
+    shape: Optional[List[Union[PrimExpr, str]]]
         The shape of the tensor. It's runtime dependent if `shape` is None.
 
     dtype: Optional[str]
@@ -66,6 +67,12 @@ def tensor(
     tensor_type: TensorType
         The TensorType that is only used in ir_builder.
     """
+    if isinstance(shape, (tuple, list)):
+        shape = list(shape)
+        for i, s in enumerate(shape):
+            if isinstance(s, str):
+                shape[i] = _tir_var("int64", s)
+
     return _ffi_api.Tensor(shape, dtype, ndim)  # pylint: disable=no-member # type: ignore
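
In effect, tensor() now also accepts bare strings as dimensions and promotes them to int64 TIR variables. A standalone sketch of that normalization, using the public tvm.tir.Var constructor in place of the builder's _tir_var alias (normalize_shape is a hypothetical helper, for illustration only):

    from tvm import tir

    def normalize_shape(shape):
        # Copy so the caller's tuple/list is not mutated
        shape = list(shape)
        for i, s in enumerate(shape):
            if isinstance(s, str):
                # A bare string names a symbolic int64 dimension
                shape[i] = tir.Var(s, "int64")
        return shape

    dims = normalize_shape(("m", 4))
    # dims[0] is a tir.Var named "m" with dtype int64; dims[1] stays the literal 4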

python/tvm/script/parser/core/parser.py

Lines changed: 12 additions & 4 deletions
@@ -80,9 +80,16 @@ def pop_frame():
         self.frames.append(VarTableFrame())
         return _deferred(pop_frame)
 
-    def add(self, var: str, value: Any):
-        self.frames[-1].add(var)
-        self.name2value[var].append(value)
+    def add(self, var: str, value: Any, allow_shadowing: bool = False):
+        # Skip if the key and value are equal to those in the var_table
+        if self.name2value[var] and self.name2value[var][-1] == value:
+            return
+        if allow_shadowing and var in self.frames[-1].vars:
+            # Shadowing
+            self.name2value[var][-1] = value
+        else:
+            self.frames[-1].add(var)
+            self.name2value[var].append(value)
 
     def get(self) -> Dict[str, Any]:
         return {key: values[-1] for key, values in self.name2value.items() if values}
@@ -177,13 +184,14 @@ def eval_assign(
         target: doc.expr,
         source: Any,
         bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+        allow_shadowing: bool = False,
     ) -> Dict[str, Any]:
         if self._duplicate_lhs_check(target) is True:
             self.report_error(target, "Duplicate vars assigned.")
         var_values = eval_assign(self, target, source)
         for k, v in var_values.items():
             var = bind_value(self, target, k, v)
-            self.var_table.add(k, var)
+            self.var_table.add(k, var, allow_shadowing)
         return var_values
 
     def report_error(
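
The shadowing behavior lives in VarTable.add: a name bound again in the same frame with allow_shadowing overwrites the top of its value stack instead of pushing a new entry. A minimal self-contained model of that logic (SimpleVarTable is a hypothetical stand-in for the parser's class, reduced to a single flat frame):

    from collections import defaultdict

    class SimpleVarTable:
        def __init__(self):
            self.name2value = defaultdict(list)  # name -> stack of values
            self.frame_vars = set()              # names defined in the current frame

        def add(self, var, value, allow_shadowing=False):
            if self.name2value[var] and self.name2value[var][-1] == value:
                return  # re-adding an identical binding is a no-op
            if allow_shadowing and var in self.frame_vars:
                self.name2value[var][-1] = value  # shadow: overwrite in place
            else:
                self.frame_vars.add(var)
                self.name2value[var].append(value)

    vt = SimpleVarTable()
    vt.add("y", 1)
    vt.add("y", 2, allow_shadowing=True)
    assert vt.name2value["y"] == [2]  # stack stays one deep after shadowing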

python/tvm/script/parser/relax/parser.py

Lines changed: 34 additions & 4 deletions
@@ -18,16 +18,41 @@
 
 from typing import Any
 
-from tvm import relax
+from tvm import tir, relax
 
 from ...ir_builder import relax as R
 from ...ir_builder.base import IRBuilder
 from .._core import Parser, dispatch, doc
 
 
 def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
-    # pylint: disable=unused-argument
-    if isinstance(value, relax.Expr):
+    var_table = self.var_table.get()
+
+    if isinstance(value, tir.Var):
+        if value.name and var_name != value.name:
+            self.report_error(
+                node,
+                "Cannot define TIR variables with different names. The LHS of binding should "
+                "have the same name provided in RHS.",
+            )
+        if var_name in var_table:
+            prev_value = var_table[var_name]
+            if not isinstance(prev_value, tir.Var):
+                self.report_error(
+                    node,
+                    "Cannot redefine a non-TIR-variable object to a TIR variable. Please define "
+                    "the TIR variable with another name.",
+                )
+            if prev_value.dtype != value.dtype:
+                self.report_error(
+                    node,
+                    "Expected the same dtype for TIR vars "
+                    f"but got {value.dtype} vs {prev_value.dtype}",
+                )
+            return prev_value
+        IRBuilder.name(var_name, value)
+        return value
+    elif isinstance(value, relax.Expr):
         var = R.emit(value)
         IRBuilder.name(var_name, var)
         return var
@@ -60,6 +85,11 @@ def visit_arguments(self: Parser, node: doc.arguments) -> None:
             self.report_error(arg, "Type annotation is required for function parameters.")
         param_type = self.visit_tvm_annotation(arg.annotation)
         param = R.arg(arg.arg, param_type)
+        # Define the symbolic shape var
+        if param_type.shape is not None:
+            for shape_expr in param_type.shape:
+                if isinstance(shape_expr, tir.Var):
+                    self.var_table.add(shape_expr.name, shape_expr)
 
         self.var_table.add(arg.arg, param)
 
@@ -78,7 +108,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
         self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
     lhs = node.targets[0]
     rhs = self.eval_expr(node.value)
-    self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+    self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value, allow_shadowing=True)
 
 
 @dispatch.register(token="relax", type_name="Return")
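
Concretely, bind_assign_value gives TIR-var bindings special treatment. A hedged illustration of what parses and what is rejected (f is a placeholder function; each rejected case raises tvm.error.DiagnosticError at parse time):

    @R.function
    def f(x: R.Tensor(("m", "n"), "float32")):
        m = T.var("int64", "m")   # OK: rebinds to the "m" introduced by the annotation
        n = T.var("int64")        # OK: an unnamed var unifies with "n" via the LHS name
        # k = T.var("int64", "m") # rejected: LHS name "k" differs from the RHS var name "m"
        # x = T.var("int64")      # rejected: "x" already holds a non-TIR value (the parameter)
        # m = T.var("int32")      # rejected: dtype mismatch with the int64 shape var
        return x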

tests/python/relax/test_tvmscript_parser.py

Lines changed: 63 additions & 1 deletion
@@ -20,7 +20,7 @@
 import tvm
 import tvm.testing
 
-from tvm import relax
+from tvm import relax, tir
 from tvm import IRModule
 from tvm.script.parser import ir as I, tir as T, relax as R
 
@@ -118,5 +118,67 @@ def foo(x: R.Tensor((4, 4), "float32")):
     _check(foo, bb.get()["foo"])
 
 
+def test_symbolic_shape():
+    @R.function
+    def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2):
+        m = T.var("int64", "m")
+        n = T.var("int64", "n")
+        gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32")
+        return gv0
+
+    @R.function
+    def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2):
+        m = T.var("int64")
+        n = T.var("int64")
+        gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32")
+        return gv0
+
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2):
+            m = T.var("int64")
+            n = T.var("int32")  # The shape dtype should be int64
+            gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32")
+            return gv0
+
+    def _expected(name: str):
+        n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+        x = relax.Var("x", [m, n], relax.DynTensorType(2, "float32"))
+        bb = relax.BlockBuilder()
+        with bb.function(name, (x,)):
+            out = bb.emit(relax.call_tir("extern_func", x, (m, n), dtype="float32"))
+            bb.emit_func_output(out)
+        return bb.get()[name]
+
+    _check(foo, _expected("foo"))
+    _check(bar, _expected("bar"))
+
+
+def test_shadowing():
+    @R.function
+    def foo(x: R.Tensor((4, 4), "float32")):
+        y = R.add(x, x)
+        z = R.multiply(x, y)
+        y = R.add(x, y)
+        y = z
+        y = R.multiply(y, x)
+        z = y
+        return z
+
+    x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x,)):
+        y = bb.emit(relax.op.add(x, x))
+        z = bb.emit(relax.op.multiply(x, y))
+        y = bb.emit(relax.op.add(x, y))
+        y = bb.emit(z)
+        y = bb.emit(relax.op.multiply(y, x))
+        z = bb.emit(y)
+        bb.emit_func_output(z)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
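
The new tests can be run on their own with a standard pytest invocation (assuming a relax-enabled TVM build):

    python -m pytest tests/python/relax/test_tvmscript_parser.py -k "symbolic_shape or shadowing"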
