Skip to content

Commit

Permalink
[ONNX] [Relay] Update unique operator to match ONNX output (1D only) (a…
Browse files Browse the repository at this point in the history
…pache#8099)

* Fix topi test case and docs (tvm was returning inverse_indices and claiming it was indices)

* Passes on CPU, fix unique op test

* more changes

* mtrying to fix optional outputs in onnx importer

* TupleGetItem is being passed a stringgit add python/tvm/relay/frontend/onnx.py debugging print statements

* Unique is passing onnx unit tests

* fix indices

* change comment

* fix return of compute unique

* black

* fix lint

* Some stray .asnumpy()s got through my merge, fix)

* fix lint

* revert changed .numpys

* missed a few

* fix more .asnumpy

* fix black

* Fix op level 3 test

* remove prints

* Fix pytorch and tf importers

* black

* fix lint

* fix indentation

* fix topi test
  • Loading branch information
electriclilies authored and mehrdadh committed Jun 3, 2021
1 parent 8b75e13 commit 51f1c03
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 112 deletions.
48 changes: 44 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2955,6 +2955,39 @@ def _impl_v11(cls, inputs, attr, params):
return out


class Unique(OnnxOpConverter):
"""Operator converter for unique"""

@classmethod
def _impl_v11(cls, inputs, attr, params):
if len(inputs) != 1:
raise ValueError("Unique expects 1 input")

data = inputs[0]
axis = attr.get("axis", None)
if axis is None: # If axis is None, flatten the input before calling unique
data = _op.reshape(data, _op.const([-1]))
else:
data_shape = infer_shape(data)
if len(data_shape) != 1:
raise ValueError("TVM only supports 1D Unique operator.")
is_sorted = attr.get("sorted", 1) # sorted is 0 or 1, 1 by default

# ONNX documentation lists return_counts as optional but there is no input to specify
# whether it is returned. Therefore we'll just always return it.
unique = _op.unique(data, is_sorted=(is_sorted == 1), return_counts=True)
num_unique = unique[3]

trim_unique_lambda = lambda input: _op.strided_slice(input, _op.const([0]), num_unique)

unique_vals = trim_unique_lambda(unique[0])
indices = trim_unique_lambda(unique[1])
inverse_indices = unique[2]
counts = trim_unique_lambda(unique[4])
# ONNX unique returns unique, indices, inverse_indices, (optional) counts
return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -3118,6 +3151,7 @@ def _get_convert_map(opset):
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down Expand Up @@ -3306,6 +3340,12 @@ def from_onnx(self, graph, opset, get_output_expr=False):
outputs_num = 1
else:
outputs_num = len(op)

if outputs_num == 1:
op = fold_constant(op)
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

if outputs_num > 1:
# ONNX supports optional outputs for some nodes.
# This block searches for missing outputs in the ONNX graph
Expand All @@ -3327,8 +3367,8 @@ def from_onnx(self, graph, opset, get_output_expr=False):
# Create the new op with valid outputs
if len(outputs) == 1:
op = outputs[0]
else:
op = _expr.TupleWrapper(outputs, len(outputs))
elif len(outputs) != outputs_num:
op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
# Drop invalid outputs for the onnx node
outputs_num = len(outputs)
node_output = [output for output in node_output if output != ""]
Expand All @@ -3337,10 +3377,10 @@ def from_onnx(self, graph, opset, get_output_expr=False):
), "Number of output mismatch {} vs {} in {}.".format(
len(node_output), outputs_num, op_name
)

if outputs_num == 1:
self._nodes[node_output[0]] = fold_constant(op)
self._nodes[node_output[0]] = op
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]

Expand Down
10 changes: 6 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,16 +2294,18 @@ def unique(self, inputs, input_types):
logging.warning("TVM always assumes sorted=True for torch.unique")
is_sorted = True
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
[unique, indices, inverse_indices, num_uniq, counts] = _op.unique(
data, is_sorted=is_sorted, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices, counts_sliced)
return (unique_sliced, inverse_indices, counts_sliced)
else:
[unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False)
[unique, indices, inverse_indices, num_uniq] = _op.unique(
data, is_sorted=is_sorted, return_counts=False
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices)
return (unique_sliced, inverse_indices)

# Operator mappings
def create_convert_map(self):
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,19 +2702,21 @@ def _impl(inputs, attr, params, mod):
assert len(inputs) == 1
data = inputs[0]
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
[unique, _, inverse_indices, num_uniq, counts] = _op.unique(
data, is_sorted=False, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices, counts_sliced]),
_expr.Tuple([unique_sliced, inverse_indices, counts_sliced]),
3,
)
[unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False)
[unique, _, inverse_indices, num_uniq] = _op.unique(
data, is_sorted=False, return_counts=False
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices]),
_expr.Tuple([unique_sliced, inverse_indices]),
2,
)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,24 +1045,28 @@ def ensure_tensor(tensor):
def _unique_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
inverse_indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
inverse_indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
return (unique_shape, indices_shape, num_unique_shape)
return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape)


@script
def _unique_with_counts_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
inverse_indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
counts_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
inverse_indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
counts_shape[0] = data_shape[0]
return (unique_shape, indices_shape, num_unique_shape, counts_shape)
return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape, counts_shape)


@_reg.register_shape_func("unique", False)
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,20 +1658,24 @@ def unique(data, is_sorted=True, return_counts=False):
data : relay.Expr
A 1-D tensor of integers.
sorted : bool
is_sorted : bool
Whether to sort the unique elements in ascending order before returning as output.
return_counts : bool
Whether to return the count of each unique element.
Returns
-------
output : relay.Expr
unique : relay.Expr
A 1-D tensor containing the unique elements of the input data tensor.
indices : relay.Expr
A 1-D tensor containing the index of each data element in the output tensor.
inverse_indices : relay.Expr
A 1-D tensor. For each entry in data, it contains the index of that data element in the
unique array.
num_unique : relay.Expr
A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
Expand All @@ -1698,5 +1702,5 @@ def unique(data, is_sorted=True, return_counts=False):
num_unique = [5]
"""
if return_counts:
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
83 changes: 48 additions & 35 deletions python/tvm/topi/cuda/unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _calc_num_unique(inc_scan):


def _calc_unique_ir(
data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts
data, argsorted_indices, inc_scan, index_converter, unique_elements, inverse_indices, counts
):
"""Low level IR to calculate unique elements, inverse indices, and counts (optional) of
unique elements of 1-D array.
Expand All @@ -143,7 +143,7 @@ def _calc_unique_ir(
unique_elements : Buffer
A buffer that stores the unique elements.
indices : Buffer
inverse_indices : Buffer
A buffer that stores the the index of each input data element in the unique element array.
counts (optional) : Buffer
Expand All @@ -154,7 +154,7 @@ def _calc_unique_ir(
argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
inc_scan_ptr = ib.buffer_ptr(inc_scan)
unique_elements_ptr = ib.buffer_ptr(unique_elements)
indices_ptr = ib.buffer_ptr(indices)
inverse_indices_ptr = ib.buffer_ptr(inverse_indices)

index_converter_ptr = None
if isinstance(index_converter, tir.Buffer):
Expand All @@ -163,7 +163,7 @@ def _calc_unique_ir(
if isinstance(counts, tir.Buffer):
counts_ptr = ib.buffer_ptr(counts)
# use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1]
unique_seq_indices_ptr = ib.buffer_ptr(indices)
unique_seq_indices_ptr = ib.buffer_ptr(inverse_indices)

batch_size = data.shape[0]
max_threads = _get_max_threads(batch_size)
Expand Down Expand Up @@ -218,7 +218,7 @@ def _calc_unique_ir(
if not index_converter_ptr
else index_converter_ptr[inc_scan_ptr[tid]]
)
indices_ptr[data_idx] = unique_idx
inverse_indices_ptr[data_idx] = unique_idx
with ib.if_scope(tid == 0):
unique_elements_ptr[unique_idx] = data_ptr[data_idx]
with ib.else_scope():
Expand Down Expand Up @@ -293,11 +293,20 @@ def unique(data, is_sorted=True, return_counts=False):
Returns
-------
output : tvm.te.Tensor
A 1-D tensor containing the unique elements of the input data tensor.
unique : tvm.te.Tensor
A 1-D tensor containing the unique elements of the input data tensor. The same size as
the input data. If there are less unique elements than input data, the end of the tensor
is padded with zeros.
indices : tvm.te.Tensor
A 1-D tensor containing the index of each data element in the output tensor.
A 1-D tensor. The same size as output. For each entry in output, it contains
the index of its first occurence in the input data. The end of the tensor is padded
with the length of the input data.
inverse_indices : tvm.te.Tensor
A 1-D tensor. For each entry in data, it contains the index of that data element in the
unique array. (Note that inverse_indices is very similar to indices if output is not
sorted)
num_unique : tvm.te.Tensor
A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
Expand All @@ -309,20 +318,23 @@ def unique(data, is_sorted=True, return_counts=False):
--------
.. code-block:: python
[output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, ?, ?, ?]
inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
[output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
counts = [2, 2, 1, 1, 2, ?, ?, ?]
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, ?, ?, ?]
inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
counts = [2, 2, 1, 1, 2, ?, ?, ?]
[output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
output = [1, 2, 3, 4, 5, ?, ?, ?]
indices = [3, 4, 0, 1, 2, 2, 3, 4]
num_unique = [5]
output = [1, 2, 3, 4, 5, ?, ?, ?]
indices = [2, 3, 4, 0, 1, ?, ?, ?]
inverse_indices = [3, 4, 0, 1, 2, 2, 3, 4]
num_unique = [5]
"""
sorted_data = sort(data)
argsorted_indices = argsort(data, dtype="int32")
Expand Down Expand Up @@ -355,29 +367,29 @@ def unique(data, is_sorted=True, return_counts=False):
out_buffers = [unique_elements_buf, inverse_indices_buf]
out_dtypes = [data.dtype, "int32"]
# prepare inputs and fcompute
# calculate first occurence
first_occurence_buf = tir.decl_buffer(
data.shape, "int32", "first_occurence_buf", data_alignment=8
)
first_occurence = te.extern(
[data.shape],
[argsorted_indices, inc_scan],
lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[argsorted_indices_buf, inc_scan_buf],
out_buffers=[first_occurence_buf],
name="_calc_first_occurence",
tag="_calc_first_occurence_gpu",
)
if is_sorted:
in_data = [data, argsorted_indices, inc_scan]
in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf]
if return_counts:
fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs)
else:
fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None)
indices = first_occurence
else:
# calculate the index converter if the unique elements should not be sorted
# calculate first occurence
first_occurence_buf = tir.decl_buffer(
data.shape, "int32", "first_occurence_buf", data_alignment=8
)
first_occurence = te.extern(
[data.shape],
[argsorted_indices, inc_scan],
lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[argsorted_indices_buf, inc_scan_buf],
out_buffers=[first_occurence_buf],
name="_calc_first_occurence",
tag="_calc_first_occurence_gpu",
)
# calculate index converter by sorting unique elements by their first occurence
argsorted_first_occurence = argsort(first_occurence, dtype="int32")
index_converter = argsort(argsorted_first_occurence, dtype="int32")
Expand All @@ -390,6 +402,7 @@ def unique(data, is_sorted=True, return_counts=False):
fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs)
else:
fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None)
indices = sort(first_occurence)
outs = te.extern(
out_data_shape,
in_data,
Expand All @@ -401,5 +414,5 @@ def unique(data, is_sorted=True, return_counts=False):
tag="_calc_unique_gpu",
)
if return_counts:
return [outs[0], outs[1], num_unique_elements, outs[2]]
return [*outs, num_unique_elements]
return [outs[0], indices, outs[1], num_unique_elements, outs[2]]
return [outs[0], indices, outs[1], num_unique_elements]
Loading

0 comments on commit 51f1c03

Please sign in to comment.