Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Add CumSum operator to ONNX frontend #7391

Merged
merged 9 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,13 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
Integer axis;
DataType dtype;
Integer exclusive;
TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
TVM_ATTR_FIELD(exclusive)
.describe("The top element is not included")
echuraev marked this conversation as resolved.
Show resolved Hide resolved
.set_default(NullValue<Integer>());
}
};

Expand Down
25 changes: 24 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value
from .common import infer_type, get_name


Expand Down Expand Up @@ -1075,6 +1075,28 @@ def _impl_v1(cls, inputs, attr, params):
return _op.shape_of(inputs[0], "int64")


class CumSum(OnnxOpConverter):
"""Operator converter for CumSum."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
dim = inputs[1]

if dim is not None:
dim = int(infer_value(dim, params).asnumpy())

exclusive = attr.get("exclusive", 0)
reverse = attr.get("reverse", 0)

if reverse != 0:
out = _op.reverse(data, axis=dim)
out = _op.cumsum(out, axis=dim, exclusive=exclusive)
return _op.reverse(out, axis=dim)

return _op.cumsum(data, axis=dim, exclusive=exclusive)


class Cast(OnnxOpConverter):
"""Operator converter for Cast."""

Expand Down Expand Up @@ -2736,6 +2758,7 @@ def _get_convert_map(opset):
"Resize": Resize.get_converter(opset),
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def compute_scatter_nd(attrs, inputs, output_type):
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def wrap_compute_cumsum(topi_compute):
"""Wrap cumsum topi compute"""

def _compute_cumsum(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]
return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

return _compute_cumsum

Expand Down
10 changes: 8 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ def adv_index(inputs):
return _make.adv_index(Tuple(inputs))


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.

Expand All @@ -1339,6 +1339,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

exclusive : int, optional
If set to 1 will return exclusive sum in which the top element is not
echuraev marked this conversation as resolved.
Show resolved Hide resolved
included. In other terms, if set to 1, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.

Returns
-------
result : relay.Expr
Expand Down Expand Up @@ -1368,4 +1374,4 @@ def cumsum(data, axis=None, dtype=None):
cumsum(a, dtype=int32) # dtype should be provided to get the expected results
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype)
return _make.cumsum(data, axis, dtype, exclusive)
10 changes: 9 additions & 1 deletion python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def traverse(op):
return s


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.

Parameters
Expand All @@ -504,6 +504,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

exclusive : int, optional
If set to 1 will return exclusive sum in which the top element is not
echuraev marked this conversation as resolved.
Show resolved Hide resolved
included. In other terms, if set to 1, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.

Returns
-------
result : tvm.te.Tensor
Expand All @@ -514,4 +520,6 @@ def cumsum(data, axis=None, dtype=None):
axis = 0
data = reshape(data, (prod(data.shape),))
axis = get_const_int(axis)
if exclusive is not None and exclusive != 0:
return exclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
21 changes: 18 additions & 3 deletions python/tvm/topi/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .math import cast


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.

Parameters
Expand All @@ -38,6 +38,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

exclusive : int, optional
If set to 1 will return exclusive sum in which the top element is not
echuraev marked this conversation as resolved.
Show resolved Hide resolved
included. In other terms, if set to 1, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.

Returns
-------
result : tvm.te.Tensor
Expand Down Expand Up @@ -75,6 +81,9 @@ def maybe_cast(x):
elif i > axis:
axis_mul_after *= value

if exclusive is None:
exclusive = 0

def gen_ir(data_buf, out_buf):
ib = ir_builder.create()
data_buf = ib.buffer_ptr(data_buf)
Expand All @@ -84,12 +93,18 @@ def gen_ir(data_buf, out_buf):
i = fused // axis_mul_after
j = fused % axis_mul_after
base_idx = i * cumsum_axis_len * axis_mul_after + j
out_buf[base_idx] = maybe_cast(data_buf[base_idx])
if exclusive == 0:
out_buf[base_idx] = maybe_cast(data_buf[base_idx])
else:
out_buf[base_idx] = 0.0
echuraev marked this conversation as resolved.
Show resolved Hide resolved
with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
k = _k + 1
cur_idx = base_idx + k * axis_mul_after
prev_idx = base_idx + (k - 1) * axis_mul_after
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
if exclusive == 0:
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
else:
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[prev_idx])

return ib.get()

Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3705,10 +3705,11 @@ bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

Expr MakeCumsum(Expr data, Integer axis, DataType dtype) {
Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Integer exclusive) {
auto attrs = make_object<CumsumAttrs>();
attrs->dtype = dtype;
attrs->axis = axis;
attrs->exclusive = exclusive;
static const Op& op = Op::Get("cumsum");
return Call(op, {data}, Attrs(attrs), {});
}
Expand Down
64 changes: 64 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,69 @@ def verify_softplus(indata):
verify_softplus(input_data)


def test_cumsum():
def verify_cumsum(indata, axis, exclusive=0, reverse=0):
cumsum_node = onnx.helper.make_node(
"CumSum",
inputs=["X", "axis"],
outputs=["Y"],
)
if exclusive != 0:
exclusive_attr = helper.make_attribute("exclusive", exclusive)
cumsum_node.attribute.append(exclusive_attr)
if reverse != 0:
reverse_attr = helper.make_attribute("reverse", reverse)
cumsum_node.attribute.append(reverse_attr)
nodes = [
make_constant_node("axis", onnx.TensorProto.INT32, [1], [axis]),
cumsum_node,
]

graph = helper.make_graph(
nodes,
"cumsum_test",
inputs=[
helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape)),
echuraev marked this conversation as resolved.
Show resolved Hide resolved
],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))],
)

model = helper.make_model(graph, producer_name="cumsum_test")

verify_with_ort_with_inputs(model, [indata], dtype="float32", use_vm=True, opset=11)

data = (
np.array(
[
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0,
10.0,
11.0,
12.0,
]
)
.astype(np.float32)
.reshape((3, 4))
)

verify_cumsum(data, 0)
verify_cumsum(data, 1)
verify_cumsum(data, 0, 1, 0)
verify_cumsum(data, 1, 1, 0)
verify_cumsum(data, 0, 0, 1)
verify_cumsum(data, 1, 0, 1)
verify_cumsum(data, 1, 1, 1)
data = np.random.randn(1, 32, 32, 3).astype("float32")
verify_cumsum(data, 1)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -4040,3 +4103,4 @@ def verify_softplus(indata):
test_size()
test_maxunpool()
test_softplus()
test_cumsum()