|
17 | 17 | # pylint: disable=invalid-name, unused-argument |
18 | 18 | """Backend QNN related feature registration""" |
19 | 19 | import numpy as np |
20 | | -from scipy import special |
| 20 | + |
21 | 21 | import tvm |
22 | 22 | from tvm import relay |
23 | 23 | from tvm._ffi.base import TVMError |
@@ -78,14 +78,29 @@ def hardswish_func(x): |
78 | 78 | register_qnn_unary_op_legalize("qnn.sqrt", np.sqrt) |
79 | 79 | register_qnn_unary_op_legalize("qnn.rsqrt", lambda arr: 1 / np.sqrt(arr)) |
80 | 80 | register_qnn_unary_op_legalize("qnn.exp", np.exp) |
81 | | -register_qnn_unary_op_legalize("qnn.erf", special.erf) |
82 | 81 | register_qnn_unary_op_legalize("qnn.sigmoid", lambda arr: 1 / (1 + np.exp(-arr))) |
83 | 82 | register_qnn_unary_op_legalize("qnn.hardswish", hardswish_func) |
84 | 83 | register_qnn_unary_op_legalize("qnn.tanh", np.tanh) |
85 | 84 | register_qnn_unary_op_legalize("qnn.log", np.log) |
86 | 85 | register_qnn_unary_op_legalize("qnn.abs", np.abs) |
87 | 86 |
|
88 | 87 |
|
| 88 | +@reg.register_qnn_legalize("qnn.erf") |
| 89 | +def _legalize_qnn_erf(attrs, inputs, types): |
| 90 | + from scipy import special # pylint: disable=import-outside-toplevel |
| 91 | + |
| 92 | + return create_integer_lookup_op( |
| 93 | + input_arg=inputs[0], |
| 94 | + floating_point_func=special.erf, |
| 95 | + in_scale=inputs[1], |
| 96 | + in_zero_point=inputs[2], |
| 97 | + out_scale=inputs[3], |
| 98 | + out_zero_point=inputs[4], |
| 99 | + in_dtype=types[0].dtype, |
| 100 | + out_dtype=types[0].dtype, |
| 101 | + ) |
| 102 | + |
| 103 | + |
89 | 104 | # Default to None. If overridden by target, this will not be run. |
90 | 105 | # Generic QNN Conv2D legalization function. |
91 | 106 | @tvm.target.generic_func |
|
0 commit comments