Skip to content

Commit 3de5e86

Browse files
authored
[Unity][nn.Module] Support Runtime-Calling Any PackedFunc via op.extern (#16274)
1 parent 1c35c39 commit 3de5e86

File tree

1 file changed

+40
-1
lines changed
  • python/tvm/relax/frontend/nn

1 file changed

+40
-1
lines changed

python/tvm/relax/frontend/nn/op.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""nn.Tensor operators."""
2020
import inspect
2121
import math
22-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
22+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
2323

2424
import numpy as np
2525

@@ -1458,6 +1458,45 @@ def _convert(arg):
14581458
)
14591459

14601460

1461+
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
1462+
1463+
1464+
def extern(
1465+
name: str,
1466+
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
1467+
out: OutType,
1468+
) -> OutType:
1469+
"""Invoke an extern function during runtime. The extern function must be registered with the "
1470+
TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python)."""
1471+
from tvm import relax as rx # pylint: disable=import-outside-toplevel
1472+
1473+
def _convert(arg, name: str):
1474+
if isinstance(arg, Tensor):
1475+
return arg._expr # pylint: disable=protected-access
1476+
if isinstance(arg, int):
1477+
return rx.PrimValue(_tir.IntImm("int64", arg))
1478+
if isinstance(arg, float):
1479+
return rx.PrimValue(_tir.FloatImm("float64", arg))
1480+
if isinstance(arg, str):
1481+
return rx.StringImm(arg)
1482+
if isinstance(arg, _tir.PrimExpr):
1483+
return rx.PrimValue(arg)
1484+
if isinstance(arg, (tuple, list)):
1485+
return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)])
1486+
raise TypeError(f"Unsupported input type: {type(arg)}")
1487+
1488+
rx_inputs = _convert(args, "input")
1489+
rx_outputs_sinfo = _convert(out, "dummy").struct_info
1490+
return wrap_nested(
1491+
_op.call_dps_packed(
1492+
name,
1493+
args=rx_inputs,
1494+
out_sinfo=rx_outputs_sinfo,
1495+
),
1496+
name,
1497+
) # type: ignore
1498+
1499+
14611500
def debug_func(
14621501
name: str,
14631502
*args: Union[Tensor, _tir.PrimExpr, int, float, str],

0 commit comments

Comments
 (0)