Skip to content

Commit 0c965f4

Browse files
vvchernovValery Chernov
andauthored
[ONNX] Support ScatterElements with reduction (#13894)
* add ScatterElements converter to ONNX front-end * native front-end for ScatterElements was implemented * update ScatterElements in ONNX high-level front-end * update comments * register ScatterElementsAttrs * register scatter elements strategy * implement generic scatter elements in topi * fix min-max redefinition * fix IntImm conversion and update scatter element implementation * fix parallel approach * CI tests for scatter elements were added * small update of description * sphinx issue was fixed * fix scatter deprecation in the CI test * fix * fix scatter version support * fix negative indices * add scatter elements strategy for cuda, gpu * update assert comment, update check of negative indices, hide tests for 18 version * fixes * extend error log for convenient analysis * lint fix * fix * sync dtypes * update cpu tir for scatter elements by scan example * scatter elements was basically implemented for topi/cuda * fix cpu scatter elements * fix gpu scatter elements * fix * small update * transfer indices check out of general loop * trancsfer ranges and strides calculation to gpu device * fixes * fix axis * clean code * fix after review * fix lint --------- Co-authored-by: Valery Chernov <[email protected]>
1 parent 0dd3d4a commit 0c965f4

File tree

14 files changed

+601
-8
lines changed

14 files changed

+601
-8
lines changed

include/tvm/relay/attrs/transform.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,18 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
164164
}
165165
};
166166

167+
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
168+
Integer axis;
169+
String reduction;
170+
171+
TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relay.attrs.ScatterElementsAttrs") {
172+
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
173+
TVM_ATTR_FIELD(reduction).set_default("update").describe(
174+
"Reduction mode of the scatter elements, "
175+
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
176+
}
177+
};
178+
167179
struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
168180
String mode;
169181

python/tvm/relay/frontend/onnx.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2848,11 +2848,59 @@ class Scatter(OnnxOpConverter):
28482848
"""Operator converter for Scatter."""
28492849

28502850
@classmethod
2851-
def _impl_v1(cls, inputs, attr, params):
2851+
def _impl_v9(cls, inputs, attr, params):
28522852
axis = attr.get("axis", 0)
28532853
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
28542854

28552855

2856+
class ScatterElements(OnnxOpConverter):
2857+
"""Operator converter for ScatterElements."""
2858+
2859+
@classmethod
2860+
def _args_check(cls, inputs, attr, red_valids=None):
2861+
ret = []
2862+
assert (
2863+
len(inputs) == 3
2864+
), "ScatterElements takes 3 inputs (data, indices, updates), {} given".format(len(inputs))
2865+
assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"]
2866+
2867+
axis = attr.get("axis", 0)
2868+
rank = len(infer_shape(inputs[0]))
2869+
assert rank > 0, "Data rank higher than 0 is expected"
2870+
assert -rank <= axis < rank, "Axis is out of bounds"
2871+
ret.append(axis)
2872+
2873+
if red_valids:
2874+
reduction = attr.get("reduction", None)
2875+
if reduction is None:
2876+
reduction = b"update"
2877+
reduction = reduction.decode("utf-8")
2878+
assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format(
2879+
red_valids, reduction
2880+
)
2881+
ret.append(reduction)
2882+
2883+
return ret
2884+
2885+
@classmethod
2886+
def _impl_v11(cls, inputs, attr, params):
2887+
axis = cls._args_check(inputs, attr)
2888+
2889+
return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, "update")
2890+
2891+
@classmethod
2892+
def _impl_v16(cls, inputs, attr, params):
2893+
axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul"])
2894+
2895+
return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction)
2896+
2897+
@classmethod
2898+
def _impl_v18(cls, inputs, attr, params):
2899+
axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul", "min", "max"])
2900+
2901+
return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction)
2902+
2903+
28562904
class ScatterND(OnnxOpConverter):
28572905
"""Operator converter for ScatterND."""
28582906

@@ -6588,7 +6636,7 @@ def _get_convert_map(opset):
65886636
"Compress": Compress.get_converter(opset),
65896637
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
65906638
"Scatter": Scatter.get_converter(opset),
6591-
"ScatterElements": Scatter.get_converter(opset),
6639+
"ScatterElements": ScatterElements.get_converter(opset),
65926640
"ScatterND": ScatterND.get_converter(opset),
65936641
"EyeLike": EyeLike.get_converter(opset),
65946642
"Squeeze": Squeeze.get_converter(opset),

python/tvm/relay/op/_transform.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ def compute_scatter_add(attrs, inputs, output_type):
204204

205205
_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
206206

207+
# scatter_elements
208+
@_reg.register_compute("scatter_elements")
209+
def compute_scatter_elements(attrs, inputs, output_type):
210+
"""Compute definition of scatter_elements"""
211+
return [topi.scatter_elements(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)]
212+
213+
214+
_reg.register_strategy("scatter_elements", strategy.scatter_elements_strategy)
215+
207216
# scatter_nd
208217
@_reg.register_compute("scatter_nd")
209218
def compute_scatter_nd(attrs, inputs, output_type):
@@ -679,6 +688,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
679688

680689
_reg.register_shape_func("scatter", False, elemwise_shape_func)
681690
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
691+
_reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
682692
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)
683693

684694

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,11 @@ class ScatterAddAttrs(Attrs):
639639
"""Attributes used in scatter_add operators"""
640640

641641

642+
@tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs")
643+
class ScatterElementsAttrs(Attrs):
644+
"""Attributes used in scatter_elements operators"""
645+
646+
642647
@tvm._ffi.register_object("relay.attrs.ScatterNDAttrs")
643648
class ScatterNDAttrs(Attrs):
644649
"""Attributes used in scatter_nd operators"""

python/tvm/relay/op/strategy/cuda.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,20 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
10991099
return strategy
11001100

11011101

1102+
@scatter_elements_strategy.register(["cuda", "gpu"])
1103+
def scatter_elements_cuda(attrs, inputs, out_type, target):
1104+
"""scatter elements cuda strategy"""
1105+
strategy = _op.OpStrategy()
1106+
strategy.add_implementation(
1107+
wrap_compute_scatter_elements(topi.cuda.scatter_elements),
1108+
wrap_topi_schedule(topi.cuda.schedule_extern),
1109+
name="scatter_elements.cuda",
1110+
plevel=10,
1111+
)
1112+
# TODO(vvchernov): There is possible specification for rank=1 as for scatter
1113+
return strategy
1114+
1115+
11021116
@scatter_nd_strategy.register(["cuda", "gpu"])
11031117
def scatter_nd_cuda(attrs, inputs, out_type, target):
11041118
"""scatter_nd cuda strategy"""

python/tvm/relay/op/strategy/generic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,28 @@ def scatter_add_strategy(attrs, outs, out_type, target):
15801580
return strategy
15811581

15821582

1583+
# scatter_elements
1584+
@override_native_generic_func("scatter_elements_strategy")
1585+
def scatter_elements_strategy(attrs, inputs, out_type, target):
1586+
"""scatter_elements generic strategy"""
1587+
strategy = _op.OpStrategy()
1588+
strategy.add_implementation(
1589+
wrap_compute_scatter_elements(topi.scatter_elements),
1590+
wrap_topi_schedule(topi.generic.schedule_extern),
1591+
name="scatter_elements.generic",
1592+
)
1593+
return strategy
1594+
1595+
1596+
def wrap_compute_scatter_elements(topi_compute):
1597+
"""Wrap scatter_elements topi compute"""
1598+
1599+
def _compute_scatter_elements(attrs, inputs, _):
1600+
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)]
1601+
1602+
return _compute_scatter_elements
1603+
1604+
15831605
# scatter_nd
15841606
@override_native_generic_func("scatter_nd_strategy")
15851607
def scatter_nd_strategy(attrs, inputs, out_type, target):

python/tvm/relay/op/transform.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,41 @@ def scatter_add(data, indices, updates, axis):
403403
return _make.scatter_add(data, indices, updates, axis)
404404

405405

406+
def scatter_elements(data, indices, updates, axis=0, reduction="update"):
407+
"""Scatter elements with updating data by reduction of values in updates
408+
at positions defined by indices.
409+
410+
Parameters
411+
----------
412+
data : relay.Expr
413+
The input data to the operator.
414+
415+
indices : relay.Expr
416+
The index locations to update.
417+
418+
updates : relay.Expr
419+
The values to update.
420+
421+
axis : int
422+
The axis to scatter elements on. It is zero by default.
423+
424+
reduction : string, optional
425+
The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]
426+
If update, the update values will replace the input data
427+
If add, the update values will be added to the input data
428+
If mul, the update values will be multiply to the input data
429+
If min, there is choice of minimal between the update values and the input data
430+
If max, there is choice of maximal between the update values and the input data
431+
It is "update" by default
432+
433+
Returns
434+
-------
435+
ret : relay.Expr
436+
The computed result.
437+
"""
438+
return _make.scatter_elements(data, indices, updates, axis, reduction)
439+
440+
406441
def scatter_nd(data, indices, updates, mode="update"):
407442
"""Scatter values from an array and update.
408443

python/tvm/topi/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .broadcast import *
3939
from .sort import *
4040
from .scatter import *
41+
from .scatter_elements import *
4142
from .sparse_fill_empty_rows import *
4243
from .sparse_reshape import *
4344
from .scatter_add import *

python/tvm/topi/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression
4747
from .rcnn import *
4848
from .scatter import *
49+
from .scatter_elements import *
4950
from .sort import *
5051
from .conv2d_nhwc_tensorcore import *
5152
from .conv3d_ndhwc_tensorcore import *

0 commit comments

Comments
 (0)