|
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks |
18 | 18 | """Scatter operator""" |
19 | | -from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr |
20 | 19 | from ..te import extern, hybrid |
| 20 | +from ..tir import decl_buffer, expr, ir_builder |
21 | 21 |
|
22 | 22 |
|
23 | 23 | @hybrid.script |
@@ -268,63 +268,58 @@ def scatter_nd(data, indices, updates, mode): |
268 | 268 | _verify_scatter_nd_inputs(data, indices, updates) |
269 | 269 |
|
270 | 270 | def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): |
| 271 | + # pylint: disable=invalid-name |
271 | 272 | ib = ir_builder.create() |
272 | 273 |
|
273 | 274 | data = ib.buffer_ptr(data_ptr) |
274 | 275 | indices = ib.buffer_ptr(indices_ptr) |
275 | 276 | updates = ib.buffer_ptr(updates_ptr) |
276 | 277 | out = ib.buffer_ptr(out_ptr) |
277 | 278 |
|
278 | | - fused_shape = 1 |
279 | | - for i in data.shape: |
280 | | - fused_shape *= i |
281 | | - with ib.for_range(0, fused_shape) as i: |
282 | | - out[i] = data[i] |
283 | | - |
284 | 279 | # We combine all the indices dimensions but the first one into a single |
285 | 280 | # dimension so we can iterate it in single loop instead of an arbitrary |
286 | | - # number of loops. We do the same thing for all the data dimensions. |
| 281 | + # number of loops. We do the same thing for all the update dimensions. |
287 | 282 | fused_indices_dimension = 1 |
288 | 283 | for i in indices_ptr.shape[1:]: |
289 | 284 | fused_indices_dimension *= i |
290 | 285 |
|
291 | | - fused_data_dimension = 1 |
292 | | - for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: |
293 | | - fused_data_dimension *= i |
| 286 | + fused_updates_dimension = 1 |
| 287 | + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: |
| 288 | + fused_updates_dimension *= i |
| 289 | + |
| 290 | + fused_shape = 1 |
| 291 | + for i in data_ptr.shape: |
| 292 | + fused_shape *= i |
| 293 | + |
| 294 | + with ib.for_range(0, fused_shape) as i: |
| 295 | + out[i] = data[i] |
294 | 296 |
|
295 | | - with ib.for_range(0, fused_indices_dimension, name="i") as i: |
296 | | - with ib.for_range(0, fused_data_dimension, name="j") as j: |
297 | | - offset = fused_data_dimension |
| 297 | + with ib.for_range(0, fused_indices_dimension) as i: |
| 298 | + with ib.for_range(0, fused_updates_dimension, kind="parallel") as j: |
| 299 | + offset = fused_updates_dimension |
298 | 300 | index = j # This is x_M, .. x_{N-1} part of the index into out. |
299 | 301 | # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part |
300 | 302 | # of the index into out. |
301 | 303 | for l in reversed(range(indices_ptr.shape[0].value)): |
302 | 304 | # indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{K-1}]
303 | 305 | index += offset * indices[i + l * fused_indices_dimension] |
304 | | - ib.emit( |
305 | | - AssertStmt( |
306 | | - indices[i + l * fused_indices_dimension] < shape[l], |
307 | | - StringImm("index out of bounds"), |
308 | | - Evaluate(0), |
309 | | - ) |
310 | | - ) |
311 | | - offset *= shape[l] |
312 | | - if mode == "add": |
313 | | - out[index] += updates[i * fused_data_dimension + j] |
314 | | - elif mode == "update": |
315 | | - out[index] = updates[i * fused_data_dimension + j] |
| 306 | + offset *= data_ptr.shape[l] |
| 307 | + if mode == "update": |
| 308 | + out[index] = updates[i * fused_updates_dimension + j] |
| 309 | + elif mode == "add": |
| 310 | + out[index] += updates[i * fused_updates_dimension + j] |
316 | 311 | else: |
317 | 312 | raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) |
318 | 313 |
|
319 | 314 | return ib.get() |
320 | 315 |
|
321 | | - out_buf = decl_buffer(shape, data.dtype, "out_buf") |
| 316 | + out_buf = decl_buffer(data.shape, data.dtype, "out_buf") |
322 | 317 | return extern( |
323 | | - [shape], |
| 318 | + [data.shape], |
324 | 319 | [data, indices, updates], |
325 | 320 | lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), |
326 | 321 | dtype=data.dtype, |
327 | 322 | out_buffers=[out_buf], |
328 | | - name="scatter_nd_generic", |
329 | | - tag="scatter_nd_generic", |
| 323 | + name="scatter_nd.generic", |
| 324 | + tag="scatter_nd.generic", |
330 | 325 | ) |
0 commit comments