|
19 | 19 | """nn.Tensor operators.""" |
20 | 20 | import inspect |
21 | 21 | import math |
22 | | -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union |
| 22 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union |
23 | 23 |
|
24 | 24 | import numpy as np |
25 | 25 |
|
@@ -1458,6 +1458,45 @@ def _convert(arg): |
1458 | 1458 | ) |
1459 | 1459 |
|
1460 | 1460 |
|
| 1461 | +OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]]) |
| 1462 | + |
| 1463 | + |
| 1464 | +def extern( |
| 1465 | + name: str, |
| 1466 | + args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]], |
| 1467 | + out: OutType, |
| 1468 | +) -> OutType: |
| 1469 | + """Invoke an extern function during runtime. The extern function must be registered with the " |
| 1470 | + TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python).""" |
| 1471 | + from tvm import relax as rx # pylint: disable=import-outside-toplevel |
| 1472 | + |
| 1473 | + def _convert(arg, name: str): |
| 1474 | + if isinstance(arg, Tensor): |
| 1475 | + return arg._expr # pylint: disable=protected-access |
| 1476 | + if isinstance(arg, int): |
| 1477 | + return rx.PrimValue(_tir.IntImm("int64", arg)) |
| 1478 | + if isinstance(arg, float): |
| 1479 | + return rx.PrimValue(_tir.FloatImm("float64", arg)) |
| 1480 | + if isinstance(arg, str): |
| 1481 | + return rx.StringImm(arg) |
| 1482 | + if isinstance(arg, _tir.PrimExpr): |
| 1483 | + return rx.PrimValue(arg) |
| 1484 | + if isinstance(arg, (tuple, list)): |
| 1485 | + return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)]) |
| 1486 | + raise TypeError(f"Unsupported input type: {type(arg)}") |
| 1487 | + |
| 1488 | + rx_inputs = _convert(args, "input") |
| 1489 | + rx_outputs_sinfo = _convert(out, "dummy").struct_info |
| 1490 | + return wrap_nested( |
| 1491 | + _op.call_dps_packed( |
| 1492 | + name, |
| 1493 | + args=rx_inputs, |
| 1494 | + out_sinfo=rx_outputs_sinfo, |
| 1495 | + ), |
| 1496 | + name, |
| 1497 | + ) # type: ignore |
| 1498 | + |
| 1499 | + |
1461 | 1500 | def debug_func( |
1462 | 1501 | name: str, |
1463 | 1502 | *args: Union[Tensor, _tir.PrimExpr, int, float, str], |
|
0 commit comments