Skip to content

Commit 48842d7

Browse files
author
Tristan Konolige
authored
[Fix,TOPI] Consolidate generic and x86 scatter nd (#13755)
The generic scatter nd was almost identical to the x86 one and was not tested. They are now one and the same.
1 parent a9c6f13 commit 48842d7

File tree

5 files changed

+26
-155
lines changed

5 files changed

+26
-155
lines changed

python/tvm/relay/op/strategy/x86.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
768768
"""scatter_nd x86 strategy"""
769769
strategy = _op.OpStrategy()
770770
strategy.add_implementation(
771-
wrap_compute_scatter_nd(topi.x86.scatter_nd),
771+
wrap_compute_scatter_nd(topi.scatter_nd),
772772
wrap_topi_schedule(topi.generic.schedule_extern),
773773
name="scatter_nd.x86",
774774
plevel=10,

python/tvm/topi/scatter.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
# under the License.
1717
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
1818
"""Scatter operator"""
19-
from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
2019
from ..te import extern, hybrid
20+
from ..tir import decl_buffer, expr, ir_builder
2121

2222

2323
@hybrid.script
@@ -268,63 +268,58 @@ def scatter_nd(data, indices, updates, mode):
268268
_verify_scatter_nd_inputs(data, indices, updates)
269269

270270
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
271+
# pylint: disable=invalid-name
271272
ib = ir_builder.create()
272273

273274
data = ib.buffer_ptr(data_ptr)
274275
indices = ib.buffer_ptr(indices_ptr)
275276
updates = ib.buffer_ptr(updates_ptr)
276277
out = ib.buffer_ptr(out_ptr)
277278

278-
fused_shape = 1
279-
for i in data.shape:
280-
fused_shape *= i
281-
with ib.for_range(0, fused_shape) as i:
282-
out[i] = data[i]
283-
284279
# We combine all the indices dimensions but the first one into a single
285280
# dimension so we can iterate it in single loop instead of an arbitrary
286-
# number of loops. We do the same thing for all the data dimensions.
281+
# number of loops. We do the same thing for all the update dimensions.
287282
fused_indices_dimension = 1
288283
for i in indices_ptr.shape[1:]:
289284
fused_indices_dimension *= i
290285

291-
fused_data_dimension = 1
292-
for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
293-
fused_data_dimension *= i
286+
fused_updates_dimension = 1
287+
for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
288+
fused_updates_dimension *= i
289+
290+
fused_shape = 1
291+
for i in data_ptr.shape:
292+
fused_shape *= i
293+
294+
with ib.for_range(0, fused_shape) as i:
295+
out[i] = data[i]
294296

295-
with ib.for_range(0, fused_indices_dimension, name="i") as i:
296-
with ib.for_range(0, fused_data_dimension, name="j") as j:
297-
offset = fused_data_dimension
297+
with ib.for_range(0, fused_indices_dimension) as i:
298+
with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
299+
offset = fused_updates_dimension
298300
index = j # This is x_M, .. x_{N-1} part of the index into out.
299301
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
300302
# of the index into out.
301303
for l in reversed(range(indices_ptr.shape[0].value)):
302304
# indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
303305
index += offset * indices[i + l * fused_indices_dimension]
304-
ib.emit(
305-
AssertStmt(
306-
indices[i + l * fused_indices_dimension] < shape[l],
307-
StringImm("index out of bounds"),
308-
Evaluate(0),
309-
)
310-
)
311-
offset *= shape[l]
312-
if mode == "add":
313-
out[index] += updates[i * fused_data_dimension + j]
314-
elif mode == "update":
315-
out[index] = updates[i * fused_data_dimension + j]
306+
offset *= data_ptr.shape[l]
307+
if mode == "update":
308+
out[index] = updates[i * fused_updates_dimension + j]
309+
elif mode == "add":
310+
out[index] += updates[i * fused_updates_dimension + j]
316311
else:
317312
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
318313

319314
return ib.get()
320315

321-
out_buf = decl_buffer(shape, data.dtype, "out_buf")
316+
out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
322317
return extern(
323-
[shape],
318+
[data.shape],
324319
[data, indices, updates],
325320
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
326321
dtype=data.dtype,
327322
out_buffers=[out_buf],
328-
name="scatter_nd_generic",
329-
tag="scatter_nd_generic",
323+
name="scatter_nd.generic",
324+
tag="scatter_nd.generic",
330325
)

python/tvm/topi/x86/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from .sparse import *
4141
from .conv2d_alter_op import *
4242
from .dense_alter_op import *
43-
from .scatter import *
4443
from .group_conv2d import *
4544
from .math_alter_op import *
4645
from .concat import *

python/tvm/topi/x86/scatter.py

Lines changed: 0 additions & 119 deletions
This file was deleted.

tests/python/topi/python/test_topi_scatter.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
3333
lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
3434
topi.generic.schedule_extern,
3535
),
36-
"cpu": (
37-
lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
38-
topi.generic.schedule_extern,
39-
),
4036
}
4137
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
4238
tvm.topi.testing.compare_numpy_tvm(

0 commit comments

Comments
 (0)