Skip to content

Commit b956917

Browse files
vvchernov (Valery Chernov)
authored and committed
[TORCH] scatter_reduce implementation (apache#14018)
* add scatter_reduce to pytorch front-end * test for scatter_reduce was added to pytorch CI * update check * add TODOs waiting for other PRs for development * fix lint * fix min-max reduction for cpu * final clean code --------- Co-authored-by: Valery Chernov <[email protected]>
1 parent 18f2176 commit b956917

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2684,6 +2684,57 @@ def scatter_add(self, inputs, input_types):
26842684
src = inputs[3]
26852685
return _op.scatter_add(data, index, src, axis=axis)
26862686

2687+
def scatter_reduce(self, inputs, input_types):
    """Convert PyTorch aten::scatter_reduce to Relay scatter_elements.

    Parameters
    ----------
    inputs : list
        (data, dim, index, src, reduce) or
        (data, dim, index, src, reduce, include_self).
    input_types : list
        Input dtypes (unused here; shapes are taken via infer_shape).

    Returns
    -------
    Relay expression produced by _op.scatter_elements.

    Raises
    ------
    NotImplementedError
        For the "mean" reduction mode, which is not supported yet.
    """
    assert len(inputs) == 5 or len(inputs) == 6, (
        "scatter_reduce takes 5 or 6 inputs (data, dim, index, src, reduce, include_self), "
        + "but {} given".format(len(inputs))
    )
    data = inputs[0]
    dim = inputs[1]
    index = inputs[2]
    src = inputs[3]
    reduce = inputs[4]
    if len(inputs) == 6:
        include_self = inputs[5]
        # TODO(vvchernov): support include_self == False
        assert include_self, "include_self=False has not been supported for scatter_reduce yet"

    # PyTorch requires data, index and src to have the same rank; index must not
    # exceed src in any dim, nor data in any dim except the scatter dim.
    data_shape = self.infer_shape(inputs[0])
    data_rank = len(data_shape)
    index_shape = self.infer_shape(inputs[2])
    index_rank = len(index_shape)
    src_shape = self.infer_shape(inputs[3])
    src_rank = len(src_shape)
    assert data_rank == index_rank, "Index rank is not the same as data rank"
    assert data_rank == src_rank, "Src rank is not the same as data rank"

    assert 0 <= dim < data_rank, "Dim is out of bounds"

    for i in range(data_rank):
        # Messages match the checks (<=): "not exceed", not "less than".
        assert index_shape[i] <= src_shape[i], "Index dim size should not exceed src one"
        if i != dim:
            assert (
                index_shape[i] <= data_shape[i]
            ), "Index dim size should not exceed data one"

    red_valids = ["sum", "prod", "mean", "amax", "amin"]
    assert reduce in red_valids, "Only {} modes are supported, but {} was given".format(
        red_valids, reduce
    )
    if reduce == "mean":
        # TODO(vvchernov): support mean reduction
        raise NotImplementedError("Mean reduction has not been supported yet!")
    # Map PyTorch reduce names to scatter_elements reduction names.
    reduce = {"sum": "add", "prod": "mul", "amin": "min", "amax": "max"}[reduce]

    return _op.scatter_elements(data, index, src, axis=dim, reduction=reduce)
2737+
26872738
def cumsum(self, inputs, input_types):
26882739
data = inputs[0]
26892740
dim = inputs[1]
@@ -3785,6 +3836,8 @@ def create_convert_map(self):
37853836
"aten::nonzero": self.nonzero,
37863837
"aten::nonzero_numpy": self.nonzero_numpy,
37873838
"aten::scatter": self.scatter,
3839+
"aten::scatter_add": self.scatter_add,
3840+
"aten::scatter_reduce": self.scatter_reduce,
37883841
"aten::index_put": self.index_put,
37893842
"aten::scalar_tensor": self.scalar_tensor,
37903843
"aten::__interpolate": self.interpolate,
@@ -3796,7 +3849,6 @@ def create_convert_map(self):
37963849
"aten::new_empty": self.new_empty,
37973850
"aten::randn": self.randn,
37983851
"aten::bincount": self.bincount,
3799-
"aten::scatter_add": self.scatter_add,
38003852
"aten::__not__": self.logical_not,
38013853
"aten::hardswish": self.hard_swish,
38023854
"aten::hardsigmoid": self.hard_sigmoid,

python/tvm/topi/scatter_elements.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
127127
elif reduction == "mul":
128128
out[index2] *= updates[index1]
129129
elif reduction == "min":
130-
tir.min(out[index2], updates[index1])
130+
out[index2] = tir.min(out[index2], updates[index1])
131131
elif reduction == "max":
132-
tir.max(out[index2], updates[index1])
132+
out[index2] = tir.max(out[index2], updates[index1])
133133
else:
134134
raise NotImplementedError(
135135
"scatter_elements reduction not in [update, add, mul, min, max]:",

tests/python/frontend/pytorch/test_forward.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4233,6 +4233,32 @@ def test_fn_scatter_add(dim):
42334233
verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)
42344234

42354235

4236+
def test_forward_scatter_reduce():
    """test_forward_scatter_reduce"""
    # integer cannot be traced
    def test_fn_scatter_reduce(dim, reduce):
        return lambda data, index, src: torch.scatter_reduce(
            data, dim=dim, index=index, src=src, reduce=reduce
        )

    targets = ["llvm", "cuda"]
    # Two configurations: scatter along dim 0 and along dim 1.
    configs = [
        (
            0,
            torch.rand(3, 5) - 1,
            torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),
            torch.rand(2, 5) - 1,
        ),
        (
            1,
            torch.rand(2, 4) - 1,
            torch.tensor([[2], [3]]),
            torch.rand(2, 1) - 1,
        ),
    ]

    # TODO(vvchernov): support test of mean reduction and include_self=False
    for dim, in_data, in_index, in_src in configs:
        for reduce in ["sum", "prod", "amin", "amax"]:
            verify_trace_model(
                test_fn_scatter_reduce(dim, reduce), [in_data, in_index, in_src], targets
            )
4260+
4261+
42364262
def test_forward_index_put():
42374263
"""test_forward_index_put"""
42384264
# torch.index_put for 2D tensor and default accumulate (False)

0 commit comments

Comments
 (0)