Skip to content

Commit

Permalink
[Relay, TOPI] Add numpy style cumsum op (#7334)
Browse files Browse the repository at this point in the history
* Add cumsum relay/topi op

* relay tests working

* add torch frontend converter

* fix for importing detr

* fix bad merge

* begin cuda cumsum

* support non innermost axis

* support rank higher than 3

* making binop parameter

* fix overflow issue in thrust scan

* generic binop parameter working

* relay test working

* fixed for bool input

* remove pytorch change

* fix pylint

* doc update

* Update python/tvm/topi/cumsum.py

Co-authored-by: Tristan Konolige <[email protected]>

* Update tests/python/relay/test_op_level3.py

Co-authored-by: Tristan Konolige <[email protected]>

* add example outputs

* add supported input and output dtype in thrust log

* adding more loop var names

* fix cpplint

* fix missing check for the cuda target in nms thrust sort

* parallelize cpu cumsum

* making binop argument tir function

* update doc for binop

* doc update

Co-authored-by: Tristan Konolige <[email protected]>
  • Loading branch information
masahi and tkonolige authored Jan 26, 2021
1 parent ab8bc0a commit 1e0d356
Show file tree
Hide file tree
Showing 17 changed files with 625 additions and 94 deletions.
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,16 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
}
}; // struct MatrixSetDiagAttrs

/*! \brief Attributes used in cumsum operator */
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
Integer axis;
DataType dtype;
TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
12 changes: 11 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def compute_scatter_add(attrs, inputs, output_type):

_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

# scatter
# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
Expand All @@ -112,6 +112,16 @@ def compute_scatter_nd(attrs, inputs, output_type):

_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)

# cumsum
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)

#####################
# Shape functions #
#####################
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,3 +996,15 @@ def argwhere_strategy_cuda(attrs, inputs, out_type, target):
name="argwhere.cuda",
)
return strategy


@cumsum_strategy.register(["cuda", "gpu"])
def cumsum_strategy_cuda(attrs, inputs, out_type, target):
"""cumsum cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumsum(topi.cuda.cumsum),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="cumsum.cuda",
)
return strategy
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,3 +1361,24 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
name="threefry_split.generic",
)
return strategy


def wrap_compute_cumsum(topi_compute):
"""Wrap cumsum topi compute"""

def _compute_cumsum(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]

return _compute_cumsum


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
"""cumsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumsum(topi.cumsum),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumsum.generic",
)
return strategy
49 changes: 49 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,3 +1320,52 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))


def cumsum(data, axis=None, dtype=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis : int, optional
Axis along which the cumulative sum is computed. The default (None) is to compute
the cumsum over the flattened array.
dtype : string, optional
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.
Returns
-------
result : relay.Expr
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
Examples
--------
.. code-block:: python
a = [[1,2,3], [4,5,6]]
cumsum(a) # if axis is not provided, cumsum is done over the flattened input.
-> [ 1, 3, 6, 10, 15, 21]
cumsum(a, dtype="float32")
-> [ 1., 3., 6., 10., 15., 21.]
cumsum(a, axis=0) # sum over rows for each of the 3 columns
-> [[1, 2, 3],
[5, 7, 9]]
cumsum(a, axis=1)
-> [[ 1, 3, 6],
[ 4, 9, 15]]
a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array
cumsum(a, dtype=int32) # dtype should be provided to get the expected results
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype)
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .scatter import *
from .scatter_add import *
from .argwhere import *
from .cumsum import *
from . import generic
from . import nn
from . import x86
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@
from .correlation import *
from .sparse import *
from .argwhere import *
from .scan import *
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
tag="fetch_score",
)

if is_thrust_available():
target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
Expand Down
Loading

0 comments on commit 1e0d356

Please sign in to comment.