Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] [Relay] Update unique operator to match ONNX output (1D only) #8099

Merged
merged 26 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
00b861f
Fix topi test case and docs (tvm was returning inverse_indices and cl…
electriclilies May 13, 2021
a4c9242
Passes on CPU, fix unique op test
electriclilies May 14, 2021
e7cd051
more changes
electriclilies May 18, 2021
f5ae586
mtrying to fix optional outputs in onnx importer
electriclilies May 18, 2021
ca4210b
Merge branch 'main' of https://github.com/apache/incubator-tvm into o…
electriclilies May 18, 2021
78d2087
TupleGetItem is being passed a stringgit add python/tvm/relay/fronten…
electriclilies May 18, 2021
13d9af8
Unique is passing onnx unit tests
electriclilies May 21, 2021
d8cbe25
fix indices
electriclilies May 21, 2021
d7853cf
change comment
electriclilies May 21, 2021
0c1544e
fix return of compute unique
electriclilies May 21, 2021
5c8f86c
black
electriclilies May 21, 2021
e37a97b
fix lint
electriclilies May 21, 2021
e70186b
merge and clean up topi test
electriclilies May 21, 2021
7bc6545
Some stray .asnumpy()s got through my merge, fix)
electriclilies May 21, 2021
72a42d8
fix lint
electriclilies May 21, 2021
d6a013a
revert changed .numpys
electriclilies May 21, 2021
e654b27
missed a few
electriclilies May 21, 2021
e27ff5e
fix more .asnumpy
electriclilies May 21, 2021
3d50aef
fix black
electriclilies May 21, 2021
0392655
Fix op level 3 test
electriclilies May 24, 2021
baaa97b
remove prints
electriclilies May 24, 2021
0181f57
Fix pytorch and tf importers
electriclilies May 24, 2021
00f989b
black
electriclilies May 24, 2021
d7676ba
fix lint
electriclilies May 24, 2021
8fcd5a7
fix indentation
electriclilies May 25, 2021
9f42cfe
fix topi test
electriclilies May 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,6 +2953,39 @@ def _impl_v11(cls, inputs, attr, params):
return out


class Unique(OnnxOpConverter):
    """Operator converter for Unique.

    Only 1-D uniqueness is supported. When the ``axis`` attribute is absent the
    input is flattened first; when ``axis`` is present the input itself must
    already be 1-D.
    """

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        """Convert an ONNX Unique node into Relay expressions.

        Parameters
        ----------
        inputs : list
            The node inputs; exactly one entry (the data tensor) is expected.
        attr : dict
            Node attributes. ``axis`` (optional) and ``sorted`` (0 or 1,
            defaults to 1) are read.
        params : dict
            Unused by this converter.

        Returns
        -------
        _expr.TupleWrapper
            A 4-tuple ``(unique, indices, inverse_indices, counts)``.

        Raises
        ------
        ValueError
            If the node does not have exactly one input, or if ``axis`` is
            given and the input is not 1-D.
        """
        if len(inputs) != 1:
            raise ValueError("Unique expects 1 input")

        data = inputs[0]
        axis = attr.get("axis", None)
        if axis is None:  # If axis is None, flatten the input before calling unique
            data = _op.reshape(data, _op.const([-1]))
        else:
            data_shape = infer_shape(data)
            if len(data_shape) != 1:
                raise ValueError("TVM only supports 1D Unique operator.")
        is_sorted = attr.get("sorted", 1)  # sorted is 0 or 1, 1 by default

        # ONNX documentation lists return_counts as optional but there is no input to specify
        # whether it is returned. Therefore we'll just always return it.
        unique = _op.unique(data, is_sorted=(is_sorted == 1), return_counts=True)
        # With return_counts=True the op yields
        # (unique, indices, inverse_indices, num_unique, counts).
        num_unique = unique[3]

        def trim_to_num_unique(tensor):
            # Keep only the leading ``num_unique`` entries of a per-unique output.
            return _op.strided_slice(tensor, _op.const([0]), num_unique)

        unique_vals = trim_to_num_unique(unique[0])
        indices = trim_to_num_unique(unique[1])
        # inverse_indices has one entry per input element, so it is not trimmed.
        inverse_indices = unique[2]
        counts = trim_to_num_unique(unique[4])
        # ONNX unique returns unique, indices, inverse_indices, (optional) counts
        return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -3116,6 +3149,7 @@ def _get_convert_map(opset):
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down Expand Up @@ -3304,6 +3338,12 @@ def from_onnx(self, graph, opset, get_output_expr=False):
outputs_num = 1
else:
outputs_num = len(op)

if outputs_num == 1:
op = fold_constant(op)
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

if outputs_num > 1:
# ONNX supports optional outputs for some nodes.
# This block searches for missing outputs in the ONNX graph
Expand All @@ -3325,8 +3365,8 @@ def from_onnx(self, graph, opset, get_output_expr=False):
# Create the new op with valid outputs
if len(outputs) == 1:
op = outputs[0]
else:
op = _expr.TupleWrapper(outputs, len(outputs))
elif len(outputs) != outputs_num:
op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
# Drop invalid outputs for the onnx node
outputs_num = len(outputs)
node_output = [output for output in node_output if output != ""]
Expand All @@ -3335,10 +3375,10 @@ def from_onnx(self, graph, opset, get_output_expr=False):
), "Number of output mismatch {} vs {} in {}.".format(
len(node_output), outputs_num, op_name
)

if outputs_num == 1:
self._nodes[node_output[0]] = fold_constant(op)
self._nodes[node_output[0]] = op
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]

Expand Down
10 changes: 6 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,16 +2294,18 @@ def unique(self, inputs, input_types):
logging.warning("TVM always assumes sorted=True for torch.unique")
is_sorted = True
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
[unique, indices, inverse_indices, num_uniq, counts] = _op.unique(
data, is_sorted=is_sorted, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices, counts_sliced)
return (unique_sliced, inverse_indices, counts_sliced)
else:
[unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False)
[unique, indices, inverse_indices, num_uniq] = _op.unique(
data, is_sorted=is_sorted, return_counts=False
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices)
return (unique_sliced, inverse_indices)

# Operator mappings
def create_convert_map(self):
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,19 +2702,21 @@ def _impl(inputs, attr, params, mod):
assert len(inputs) == 1
data = inputs[0]
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
[unique, _, inverse_indices, num_uniq, counts] = _op.unique(
data, is_sorted=False, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices, counts_sliced]),
_expr.Tuple([unique_sliced, inverse_indices, counts_sliced]),
3,
)
[unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False)
[unique, _, inverse_indices, num_uniq] = _op.unique(
data, is_sorted=False, return_counts=False
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices]),
_expr.Tuple([unique_sliced, inverse_indices]),
2,
)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,24 +1045,28 @@ def ensure_tensor(tensor):
def _unique_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
inverse_indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
inverse_indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
return (unique_shape, indices_shape, num_unique_shape)
return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape)


@script
def _unique_with_counts_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
inverse_indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
counts_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
inverse_indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
counts_shape[0] = data_shape[0]
return (unique_shape, indices_shape, num_unique_shape, counts_shape)
return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape, counts_shape)


@_reg.register_shape_func("unique", False)
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,20 +1654,24 @@ def unique(data, is_sorted=True, return_counts=False):
data : relay.Expr
A 1-D tensor of integers.

sorted : bool
is_sorted : bool
Whether to sort the unique elements in ascending order before returning as output.

return_counts : bool
Whether to return the count of each unique element.

Returns
-------
output : relay.Expr
unique : relay.Expr
A 1-D tensor containing the unique elements of the input data tensor.

indices : relay.Expr
A 1-D tensor. For each entry in the unique output, it contains the index of its first
occurrence in the input data.

inverse_indices : relay.Expr
A 1-D tensor. For each entry in data, it contains the index of that data element in the
unique array.

num_unique : relay.Expr
A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.

Expand All @@ -1694,5 +1698,5 @@ def unique(data, is_sorted=True, return_counts=False):
num_unique = [5]
"""
if return_counts:
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
83 changes: 48 additions & 35 deletions python/tvm/topi/cuda/unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _calc_num_unique(inc_scan):


def _calc_unique_ir(
data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts
data, argsorted_indices, inc_scan, index_converter, unique_elements, inverse_indices, counts
):
"""Low level IR to calculate unique elements, inverse indices, and counts (optional) of
unique elements of 1-D array.
Expand All @@ -143,7 +143,7 @@ def _calc_unique_ir(
unique_elements : Buffer
A buffer that stores the unique elements.

indices : Buffer
inverse_indices : Buffer
A buffer that stores the index of each input data element in the unique element array.

counts (optional) : Buffer
Expand All @@ -154,7 +154,7 @@ def _calc_unique_ir(
argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
inc_scan_ptr = ib.buffer_ptr(inc_scan)
unique_elements_ptr = ib.buffer_ptr(unique_elements)
indices_ptr = ib.buffer_ptr(indices)
inverse_indices_ptr = ib.buffer_ptr(inverse_indices)

index_converter_ptr = None
if isinstance(index_converter, tir.Buffer):
Expand All @@ -163,7 +163,7 @@ def _calc_unique_ir(
if isinstance(counts, tir.Buffer):
counts_ptr = ib.buffer_ptr(counts)
# use the inverse_indices buffer as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1]
unique_seq_indices_ptr = ib.buffer_ptr(indices)
unique_seq_indices_ptr = ib.buffer_ptr(inverse_indices)

batch_size = data.shape[0]
max_threads = _get_max_threads(batch_size)
Expand Down Expand Up @@ -218,7 +218,7 @@ def _calc_unique_ir(
if not index_converter_ptr
else index_converter_ptr[inc_scan_ptr[tid]]
)
indices_ptr[data_idx] = unique_idx
inverse_indices_ptr[data_idx] = unique_idx
with ib.if_scope(tid == 0):
unique_elements_ptr[unique_idx] = data_ptr[data_idx]
with ib.else_scope():
Expand Down Expand Up @@ -293,11 +293,20 @@ def unique(data, is_sorted=True, return_counts=False):

Returns
-------
output : tvm.te.Tensor
A 1-D tensor containing the unique elements of the input data tensor.
unique : tvm.te.Tensor
A 1-D tensor containing the unique elements of the input data tensor. The same size as
the input data. If there are fewer unique elements than input data, the end of the tensor
is padded with zeros.

indices : tvm.te.Tensor
A 1-D tensor containing the index of each data element in the output tensor.
A 1-D tensor. The same size as output. For each entry in output, it contains
the index of its first occurrence in the input data. The end of the tensor is padded
with the length of the input data.

inverse_indices : tvm.te.Tensor
A 1-D tensor. For each entry in data, it contains the index of that data element in the
unique array. (Note that inverse_indices is very similar to indices if output is not
sorted)

num_unique : tvm.te.Tensor
A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
Expand All @@ -309,20 +318,23 @@ def unique(data, is_sorted=True, return_counts=False):
--------
.. code-block:: python
[output, indices, inverse_indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, ?, ?, ?]
inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]

[output, indices, inverse_indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
counts = [2, 2, 1, 1, 2, ?, ?, ?]
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, ?, ?, ?]
inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
counts = [2, 2, 1, 1, 2, ?, ?, ?]

[output, indices, inverse_indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
output = [1, 2, 3, 4, 5, ?, ?, ?]
indices = [3, 4, 0, 1, 2, 2, 3, 4]
num_unique = [5]
output = [1, 2, 3, 4, 5, ?, ?, ?]
indices = [2, 3, 4, 0, 1, ?, ?, ?]
inverse_indices = [3, 4, 0, 1, 2, 2, 3, 4]
num_unique = [5]
"""
sorted_data = sort(data)
argsorted_indices = argsort(data, dtype="int32")
Expand Down Expand Up @@ -355,29 +367,29 @@ def unique(data, is_sorted=True, return_counts=False):
out_buffers = [unique_elements_buf, inverse_indices_buf]
out_dtypes = [data.dtype, "int32"]
# prepare inputs and fcompute
# calculate first occurence
first_occurence_buf = tir.decl_buffer(
data.shape, "int32", "first_occurence_buf", data_alignment=8
)
first_occurence = te.extern(
[data.shape],
[argsorted_indices, inc_scan],
lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[argsorted_indices_buf, inc_scan_buf],
out_buffers=[first_occurence_buf],
name="_calc_first_occurence",
tag="_calc_first_occurence_gpu",
)
if is_sorted:
in_data = [data, argsorted_indices, inc_scan]
in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf]
if return_counts:
fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs)
else:
fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None)
indices = first_occurence
else:
# calculate the index converter if the unique elements should not be sorted
# calculate first occurence
first_occurence_buf = tir.decl_buffer(
data.shape, "int32", "first_occurence_buf", data_alignment=8
)
first_occurence = te.extern(
[data.shape],
[argsorted_indices, inc_scan],
lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[argsorted_indices_buf, inc_scan_buf],
out_buffers=[first_occurence_buf],
name="_calc_first_occurence",
tag="_calc_first_occurence_gpu",
)
# calculate index converter by sorting unique elements by their first occurence
argsorted_first_occurence = argsort(first_occurence, dtype="int32")
index_converter = argsort(argsorted_first_occurence, dtype="int32")
Expand All @@ -390,6 +402,7 @@ def unique(data, is_sorted=True, return_counts=False):
fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs)
else:
fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None)
indices = sort(first_occurence)
outs = te.extern(
out_data_shape,
in_data,
Expand All @@ -401,5 +414,5 @@ def unique(data, is_sorted=True, return_counts=False):
tag="_calc_unique_gpu",
)
if return_counts:
return [outs[0], outs[1], num_unique_elements, outs[2]]
return [*outs, num_unique_elements]
return [outs[0], indices, outs[1], num_unique_elements, outs[2]]
return [outs[0], indices, outs[1], num_unique_elements]
Loading