Skip to content

Commit 4721a46

Browse files
author
Siyuan Feng
committed
[Relax] Add NonZero op
this PR adds the NonZero op to Relax, together with ONNX frontend support
1 parent a5d04a5 commit 4721a46

File tree

7 files changed

+137
-2
lines changed

7 files changed

+137
-2
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2394,6 +2394,14 @@ def _impl_v11(cls, bb, inputs, attr, params):
23942394
return relax.op.unique(data, sorted=sorted, axis=axis)
23952395

23962396

2397+
class NonZero(OnnxOpConverter):
2398+
"""Converts an onnx NonZero node into an equivalent Relax expression."""
2399+
2400+
@classmethod
2401+
def _impl_v9(cls, bb, inputs, attr, params):
2402+
return relax.op.nonzero(inputs[0])
2403+
2404+
23972405
class HardSigmoid(OnnxOpConverter):
23982406
"""Converts an onnx HardSigmoid node into an equivalent Relax expression."""
23992407

@@ -2779,7 +2787,7 @@ def _get_convert_map():
27792787
"Range": Range,
27802788
"OneHot": OneHot,
27812789
"Unique": Unique,
2782-
# "NonZero": NonZero,
2790+
"NonZero": NonZero,
27832791
# "If": If,
27842792
# "LRN": LRN,
27852793
# "MaxRoiPool": MaxRoiPool,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
from .qdq import dequantize, quantize
100100
from .sampling import multinomial_from_uniform
101101
from .search import argmax, argmin, where
102-
from .set import unique
102+
from .set import nonzero, unique
103103
from .sorting import argsort, sort, topk
104104
from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance
105105
from .ternary import ewise_fma

python/tvm/relax/op/set.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,40 @@ def numpy_unique(
110110
return tvm.nd.array(output_sorted_numpy)
111111
output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis)
112112
return tvm.nd.array(output_numpy)
113+
114+
115+
def nonzero(x: Expr) -> Expr:
116+
"""Find the indices of elements of a tensor that are non-zero.
117+
118+
Parameters
119+
----------
120+
x : relax.Expr
121+
The input data tensor.
122+
123+
Returns
124+
-------
125+
result : relax.Expr
126+
A (n+1)-D tensor containing indices of non-zero elements.
127+
128+
Note
129+
----
130+
This function is equivalent to `onnx.nonzero`.
131+
132+
Examples
133+
--------
134+
135+
.. code-block:: python
136+
137+
x = [[0, 1],
138+
[2, 0]]
139+
nonzero(x) = [[0, 1],
140+
[1, 0]]
141+
142+
"""
143+
return _ffi_api.nonzero(x) # type: ignore
144+
145+
146+
@tvm.register_func("relax.run.nonzero")
147+
def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array:
148+
np_result = np.atleast_1d(x.numpy()).nonzero()
149+
return tvm.nd.array(np.stack(np_result, axis=0))

src/relax/op/tensor/set.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
#include "set.h"
2626

27+
#include <algorithm>
2728
#include <utility>
2829
#include <vector>
2930

@@ -137,5 +138,27 @@ TVM_REGISTER_OP("relax.unique")
137138
.set_attr<FCallPacked>("FCallPacked", "relax.run.unique")
138139
.set_attr<Bool>("FPurity", Bool(true));
139140

141+
/* relax.nonzero */
142+
Expr nonzero(Expr x) {
143+
static const Op& op = Op::Get("relax.nonzero");
144+
return Call(op, {std::move(x)});
145+
}
146+
147+
TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);
148+
149+
StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
150+
TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
151+
// Cheat zero dim scalar as 1-dim.
152+
int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1;
153+
return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
154+
}
155+
156+
TVM_REGISTER_OP("relax.nonzero")
157+
.set_num_inputs(1)
158+
.add_argument("x", "Tensor", "The input tensor")
159+
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNonzero)
160+
.set_attr<FCallPacked>("FCallPacked", "relax.run.nonzero")
161+
.set_attr<Bool>("FPurity", Bool(true));
162+
140163
} // namespace relax
141164
} // namespace tvm

src/relax/op/tensor/set.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,36 @@
2929
namespace tvm {
3030
namespace relax {
3131

32+
/*!
33+
* \brief Find the unique elements in a given tensor.
34+
* In addition, it optionally returns
35+
* - the indices of the input tensor that give the unique values;
36+
* - the indices of the unique tensor that reconstruct the input tensor;
37+
* - the number of times each unique value comes up in the input tensor.
38+
* \param x The input tensor.
39+
* \param sorted Whether to sort the unique elements in ascending order before
40+
* returning as output.
41+
* \param return_index Whether to return an additional tensor with indices for where elements in
42+
* the unique tensor come from the original input.
43+
* \param return_inverse Whether to return an additional tensor with indices for where elements in
44+
* the original input ended up in the returned unique list.
45+
* \param return_counts Whether to return an additional tensor with counts of each unique elements.
46+
* \param axis The dimension to apply unique.
47+
* If not specified, the unique values of the flattened input are returned.
48+
* \return The unique elements of the array. The returned array will be sorted if `sorted` is True.
49+
* Additional return values depend on `return_index`, `return_inverse`, and `return_counts`.
50+
*/
3251
Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse,
3352
PrimValue return_counts, Optional<PrimValue> axis);
53+
54+
/*!
55+
* \brief Returns the indices of the non-zero elements of the input tensor.
56+
* \param x The input tensor.
57+
* \return a list of 1-D tensors containing indices of non-zero elements for each dimension.
58+
* \note This function behaves similarly to numpy.nonzero(), but return a multi-dimensional array
59+
* instead of a tuple of 1-D arrays.
60+
*/
61+
Expr nonzero(Expr x);
3462
} // namespace relax
3563
} // namespace tvm
3664

tests/python/relax/test_frontend_onnx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,6 +2126,11 @@ def test_unique(axis: Optional[int], sorted: int):
21262126
check_correctness(model)
21272127

21282128

2129+
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
2130+
def test_nonzero(shape):
2131+
verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64)
2132+
2133+
21292134
@pytest.mark.parametrize("mode", ["DCR", "CRD"])
21302135
def test_depth_to_space(mode: Literal["DCR", "CRD"]):
21312136
in_shape = [1, 8, 2, 3]

tests/python/relax/test_op_set.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,5 +867,39 @@ def test_unique_infer_struct_info_wrong_input_dtype():
867867
bb.normalize(relax.op.unique(x1))
868868

869869

870+
@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)])
871+
def test_nonzero_infer_struct_info(shape):
872+
bb = relax.BlockBuilder()
873+
x0 = relax.Var("x", R.Tensor(shape, "bool"))
874+
875+
_check_inference(
876+
bb,
877+
relax.op.nonzero(x0),
878+
relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"),
879+
)
880+
881+
882+
def test_nonzero_infer_struct_info_ndim_zero():
883+
bb = relax.BlockBuilder()
884+
x = relax.Var("x", R.Tensor((), "bool"))
885+
886+
_check_inference(
887+
bb,
888+
relax.op.nonzero(x),
889+
relax.TensorStructInfo(ndim=2, dtype="int64"),
890+
)
891+
892+
893+
def test_nonzero_infer_struct_info_wrong_input_dtype():
894+
bb = relax.BlockBuilder()
895+
x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4)))
896+
x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32")))
897+
898+
with pytest.raises(TVMError):
899+
bb.normalize(relax.op.nonzero(x0))
900+
with pytest.raises(TVMError):
901+
bb.normalize(relax.op.nonzero(x1))
902+
903+
870904
if __name__ == "__main__":
871905
tvm.testing.main()

0 commit comments

Comments
 (0)