Skip to content

Commit 3e8560e

Browse files
Siyuan Feng authored and tqchen committed
[Unity][TVMScript] Use explicit R.shape in TVMScript (#13979)
As we've introduced `arg_sinfo` in CallNode, implicit shape constructor is not widely used in TVMScript. This PR removes the implicit shape since it may cause confusion between shape and tuple.
1 parent dc52afb commit 3e8560e

File tree

11 files changed

+93
-43
lines changed

11 files changed

+93
-43
lines changed

python/tvm/relax/utils.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ..runtime import String, convert_to_object
2424
from ..tir import PrimExpr
2525
from . import _ffi_api
26-
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
26+
from .expr import Expr, Function, PrimValue, StringImm
2727
from .expr import Tuple as rx_Tuple
2828

2929

@@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr:
7474
1. Return the input itself if it's already a `relax.Expr`;
7575
2. Return `relax.PrimValue` if the input is a `PrimExpr`;
7676
3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
77-
4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype;
78-
5. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
77+
4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
7978
8079
Notes
8180
-----
8281
1. `tvm.tir.StringImm` is not allowed because of ambiguity,
8382
which can be either `relax.StringImm` or `relax.PrimValue`.
84-
2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr`
8583
"""
8684
if isinstance(value, int):
8785
return PrimValue(tir.IntImm("int64", value))
@@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr:
102100
# Case 3
103101
if isinstance(tvm_value, String):
104102
return StringImm(value)
105-
# Case 4 & 5
103+
# Case 4
106104
if isinstance(value, (tuple, list)):
107-
# Note 2
108-
if len(value) == 0:
109-
return rx_Tuple([])
110-
# Case 4
111-
opt_prim_value = [convert_to_object(v) for v in value]
112-
if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]):
113-
return ShapeExpr(value)
114-
# Case 5
115105
# `convert_to_expr` ensures that all elements are `Expr` if no exception raises
116106
return rx_Tuple([convert_to_expr(v) for v in value])
117107
raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`")

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,23 @@ def tuple(*fields: Expr) -> Expr:
329329
return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member
330330

331331

332+
############################### R.shape ################################
333+
334+
335+
def shape(value: List[PrimExpr]) -> Expr:
336+
"""Create a ShapeExpr.
337+
Parameters
338+
----------
339+
value : List[PrimExpr]
340+
The fields of the tuple.
341+
Returns
342+
-------
343+
res : Expr
344+
The result tuple.
345+
"""
346+
return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore
347+
348+
332349
############################### PrimValue ##############################
333350

334351

@@ -407,6 +424,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
407424
"prim_value",
408425
"print",
409426
"reshape",
427+
"shape",
410428
"shape_of",
411429
"str",
412430
"tuple",

python/tvm/script/parser/relax/entry.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from tvm.relax import (
2424
Expr,
25+
ShapeExpr,
2526
FuncStructInfo,
2627
Function,
2728
ObjectStructInfo,
@@ -84,24 +85,31 @@ class TensorProxy(StructInfoProxy):
8485

8586
def __init__(
8687
self,
87-
shape: Optional[List[Union[PrimExpr, str]]] = None,
88+
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
8889
dtype: Optional[str] = None,
8990
ndim: int = -1,
9091
) -> None:
9192
self.shape = shape
93+
if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
94+
raise ValueError(
95+
"Only ShapeExpr is allowed as shape expr, but got: "
96+
f"{shape} with type: {type(shape)}"
97+
)
9298
self.dtype = dtype
9399
self.ndim = ndim
94100
super().__init__()
95101

96102
def get_symbolic_vars(self) -> Set[str]:
97-
if self.shape is None:
103+
if self.shape is None or isinstance(self.shape, Expr):
98104
return {}
99105
else:
100106
return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}
101107

102108
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
103109
if self.shape is None:
104110
return TensorStructInfo(None, self.dtype, self.ndim)
111+
elif isinstance(self.shape, ShapeExpr):
112+
return TensorStructInfo(self.shape, self.dtype, self.ndim)
105113
else:
106114
if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
107115
raise ValueError(
@@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso
113121

114122

115123
def Tensor(
116-
shape: Optional[List[Union[PrimExpr, str]]] = None,
124+
shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
117125
dtype: Optional[str] = None,
118126
ndim: int = -1,
119127
) -> TensorProxy:
@@ -124,8 +132,12 @@ def Tensor(
124132
dtype = shape
125133
shape = None
126134

127-
if shape is not None and not isinstance(shape, (tuple, list)):
128-
raise ValueError(f"shape must be a list or tuple, but got: {shape}")
135+
if (
136+
shape is not None
137+
and not isinstance(shape, (tuple, list))
138+
and not isinstance(shape, ShapeExpr)
139+
):
140+
raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}")
129141
return TensorProxy(shape, dtype, ndim)
130142

131143

src/script/printer/relax/expr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
7171
for (int i = 0, l = n->values.size(); i < l; ++i) {
7272
values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d));
7373
}
74-
return TupleDoc(values_doc);
74+
return Relax(d, "shape")->Call({ListDoc(values_doc)});
7575
});
7676

7777
Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) {

src/script/printer/relax/struct_info.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
8989
Array<String> kwargs_keys;
9090
Array<ExprDoc> kwargs_values;
9191
if (n->shape.defined()) {
92-
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
92+
// Need to dig into ShapeExpr to preserve the `R.shape` prefix
93+
if (const auto* shape = n->shape.value().as<relax::ShapeExprNode>()) {
94+
auto shape_expr = GetRef<relax::ShapeExpr>(shape);
95+
ObjectPath shape_p = n_p->Attr("shape")->Attr("values");
96+
Array<ExprDoc> shape_docs;
97+
for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) {
98+
shape_docs.push_back(
99+
PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d));
100+
}
101+
args.push_back(TupleDoc(shape_docs));
102+
} else {
103+
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
104+
}
93105
}
94106
if (!n->IsUnknownDtype()) {
95107
kwargs_keys.push_back("dtype");

tests/python/relax/test_backend_transform_shape_lower.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def main(
167167
n = T.Var("n", "int64")
168168
k = T.Var("k", "int64")
169169
z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
170-
return (k + 1, m, 2)
170+
return R.shape([k + 1, m, 2])
171171

172172
# slot assignment:
173173
# 0: n, 1: m, 2:k, 3: k+1

tests/python/relax/test_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class TestVMBuiltinLower:
109109
@R.function
110110
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
111111
m, n = T.var("int64"), T.var("int64")
112-
alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
112+
alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
113113
_ = R.call_packed(
114114
"test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
115115
)

tests/python/relax/test_tvmscript_parser.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
import tvm.script
2323
import tvm.testing
2424
from tvm import IRModule, relax, tir, topi
25-
from tvm.relax import DynTensorType
26-
from tvm.script import ir as I
27-
from tvm.script import relax as R
28-
from tvm.script import tir as T
25+
from tvm.script.parser import ir as I
26+
from tvm.script.parser import relax as R
27+
from tvm.script.parser import tir as T
2928

3029

3130
def _check(
@@ -202,6 +201,23 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
202201
_check(foo, bb.get()["foo"])
203202

204203

204+
def test_relax_base_op():
205+
@R.function
206+
def foo(x: R.Tensor((4, 4), "float32")):
207+
alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32")
208+
shape = R.shape_of(alloc)
209+
return shape
210+
211+
x = relax.Var("x", R.Tensor((4, 4), "float32"))
212+
bb = relax.BlockBuilder()
213+
with bb.function("foo", (x,)):
214+
alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0))
215+
shape = bb.emit(relax.op.shape_of(alloc))
216+
bb.emit_func_output(shape)
217+
# todo(yongwww): comment this check because 0 was changed to R.prim_value(0) in the printed IR
218+
# _check(foo, bb.get()["foo"])
219+
220+
205221
def test_symbolic_shape():
206222
@R.function
207223
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
@@ -274,7 +290,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
274290
y0 = R.match_cast(y, R.Tensor([n], "float32"))
275291
gv = y0
276292
R.output(gv)
277-
return (x0, (m, n * 2))
293+
return (x0, R.shape([m, n * 2]))
278294

279295
x = relax.Var("x", R.Tensor("float32"))
280296
y = relax.Var("y", R.Tensor("float32"))
@@ -314,7 +330,7 @@ def test_tuple_return_2():
314330
def foo(x: R.Tensor("float32", ndim=2)):
315331
n, m = T.var("int64"), T.var("int64")
316332
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
317-
return (x0, (n + 1, m, 1))
333+
return (x0, R.shape([n + 1, m, 1]))
318334

319335
x = relax.Var("x", R.Tensor("float32", ndim=2))
320336
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
@@ -332,7 +348,7 @@ def foo(x: R.Tensor("float32", ndim=2)):
332348
n, m = T.var("int64"), T.var("int64")
333349
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
334350
t0 = (x, x0)
335-
t1 = (x, (n, m), t0)
351+
t1 = (x, R.shape([n, m]), t0)
336352
return t1
337353

338354
x = relax.Var("x", R.Tensor("float32", ndim=2))
@@ -965,9 +981,9 @@ def test_vm_ops():
965981
def foo(x: R.Tensor(("m", "n"), dtype="float32")):
966982
m = T.var("int64")
967983
n = T.var("int64")
968-
storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0)
969-
alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32")
970-
tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0)
984+
storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0)
985+
alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32")
986+
tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
971987
_ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n)))
972988
gv = tensor
973989
return alloc, gv

tests/python/relax/test_tvmscript_printer_relax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_tuple_get_item():
292292

293293
def test_shape_expr():
294294
obj = relax.ShapeExpr([1, 2, 3])
295-
_assert_print(obj, "(1, 2, 3)")
295+
_assert_print(obj, "R.shape([1, 2, 3])")
296296

297297

298298
def test_call():
@@ -304,7 +304,7 @@ def test_call():
304304
"""
305305
x = T.Var("x", "int64")
306306
a: R.Tensor((1, x, 3), dtype="float32")
307-
R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,))
307+
R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x]))
308308
""",
309309
)
310310

tests/python/relax/test_vm_build.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class TestVMCompileStage2:
8888
def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
8989
n, m = T.var("int64"), T.var("int64")
9090
_ = R.match_cast(x, R.Tensor((n, m), "float32"))
91-
return (n * 2, m * 3)
91+
return R.shape([n * 2, m * 3])
9292

9393
mod = TestVMCompileStage2
9494
target = tvm.target.Target("llvm", host="llvm")
@@ -511,9 +511,9 @@ class TestMemoryAllocStorageTensor:
511511
@R.function
512512
def main(x: R.Tensor((2, 3), dtype="float32")):
513513
storage = R.memory.alloc_storage(
514-
(24,), virtual_device_index=0, storage_scope="global", dtype="float32"
514+
R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
515515
)
516-
y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32")
516+
y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32")
517517
_ = copy(x, y)
518518
return y
519519

0 commit comments

Comments (0)