Skip to content

Commit 26d9a0b

Browse files
author
Siyuan Feng
committed
[Unity][TVMScript] Update call_packed semantics to support empty sinfo_args
In low-level Relax (after pass `CallTIRewrite`), the `call_packed` nodes do not always have explicit `sinfo_args`. This PR extents the parser to support this case.
1 parent cf14edd commit 26d9a0b

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def output(*vars: Tuple[Var]) -> None:
330330
def call_packed(
331331
func: py_str,
332332
*args: Expr,
333-
sinfo_args: Union[StructInfo, List[StructInfo]],
333+
sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None,
334334
**kwargs: Any,
335335
) -> Call:
336336
"""Create a relax Call, which calls a packed function.
@@ -340,7 +340,7 @@ def call_packed(
340340
The name of extern function.
341341
*args : Expr
342342
The arguments.
343-
sinfo_args: Union[StructInfo, List[StructInfo]]
343+
sinfo_args: Optional[Union[StructInfo, List[StructInfo]]]
344344
The list of structure info arguments.
345345
kwargs: Expr
346346
The keyword arguments.
@@ -352,7 +352,7 @@ def call_packed(
352352
"""
353353
op = ExternFunc(func)
354354
if sinfo_args is None:
355-
raise ValueError("R.call_packed is required to have type_args")
355+
sinfo_args = []
356356
if isinstance(sinfo_args, py_tuple): # type: ignore
357357
sinfo_args = list(sinfo_args)
358358
elif not isinstance(sinfo_args, list):

tests/python/relax/test_tvmscript_parser.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,28 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
842842
_check(foo, bb.get()["foo"])
843843

844844

845+
def test_call_packed_without_sinfo_args():
846+
@R.function
847+
def foo(x: R.Object) -> R.Object:
848+
z = R.call_packed("test", x)
849+
return z
850+
851+
x = relax.Var("x", R.Object())
852+
bb = relax.BlockBuilder()
853+
with bb.function("foo", (x)):
854+
z = bb.emit(
855+
relax.Call(
856+
relax.ExternFunc("test"),
857+
(x,),
858+
None,
859+
sinfo_args=[],
860+
)
861+
)
862+
bb.emit_func_output(z)
863+
864+
_check(foo, bb.get()["foo"])
865+
866+
845867
def test_annotation():
846868
@R.function
847869
def foo(

0 commit comments

Comments
 (0)