[Operator] Add index_add or index_update to numpy extension #17823
Comments
Besides that, it would be great if
@zheyuye @JiangZhaoh @yzhliu @haojin2 To understand the problem, let's consider two use cases. The first one can be solved via

Take elements at specific locations from the input data

In GluonNLP, the

With advanced indexing + imperative API, we can do something like this:

    import mxnet as mx
    mx.npx.set_np()
    data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
    positions = mx.np.random.randint(0, 5, (5, 4))
    out = data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions]
    print(out.asnumpy().shape)

In order to make the network hybridizable, we can implement it as follows:

    import numpy as np
    from mxnet.util import use_np

    @use_np
    def select_vectors_by_position(F, data, positions):
        """Select each batch with the given positions.

        Once advanced indexing can be hybridized, we can revise the implementation.

        out[i, j, :] = data[i, positions[i, j], :]

        Parameters
        ----------
        F
        data
            Input tensor of contextualized token embeddings
            Shape (batch_size, seq_length, units)
        positions
            Input tensor of the positions.
            Shape (batch_size, num_sel_positions).
            For each sample in the batch, the values in this tensor must not exceed
            the length of the sequence.

        Returns
        -------
        out
            The selection result.
            Shape (batch_size, num_sel_positions, units)
        """
        # Here, we use gather_nd to select the output from data:
        # Need to compute
        #   out[i, j, :] = data[i, positions[i, j], :]
        # Thus, construct an indices tensor with shape
        # [2, batch_size, num_sel_positions], where
        #     indices[0, i, j] = i
        #     indices[1, i, j] = positions[i, j]
        # Then, out = gather_nd(data, indices)
        positions = positions.astype(np.int32)
        # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
        batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0),
                                     axis=1).astype(np.int32)
        # Broadcast batch_idx to the shape of positions
        batch_idx = batch_idx + F.np.zeros_like(positions)
        indices = F.np.stack([batch_idx, positions])
        out = F.npx.gather_nd(data, indices)
        return out

Update elements at specific locations of the input data

For example, if we have some selected locations and need to replace the elements with our own generated elements, i.e.,
With advanced indexing + imperative API, we can do something like this:

    import mxnet as mx
    import numpy.testing as npt
    mx.npx.set_np()
    data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
    positions = mx.np.random.randint(0, 5, (5, 4))
    update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
    data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
    print(data.asnumpy().shape)
    # or do
    data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] += update_val
    print(data.asnumpy().shape)

However, we cannot surround it with autograd:

    import mxnet as mx
    import numpy.testing as npt
    mx.npx.set_np()
    data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
    positions = mx.np.random.randint(0, 5, (5, 4))
    update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
    data.attach_grad()
    with mx.autograd.record():
        data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
    mx.npx.waitall()

Error message:

We will need to have a workaround solution to this use case.
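
As a reference for the semantics of the two use cases above, here is a minimal sketch in plain NumPy (not MXNet, and not hybridizable). The helper names `select_by_position` and `update_by_position` are hypothetical. The first builds the same stacked `[batch_idx, positions]` index as the `gather_nd` version above, and the second returns an updated copy instead of writing into `data` in place, which is roughly the shape an autograd-friendly operator would need.

```python
import numpy as np


def select_by_position(data, positions):
    """out[i, j, ...] = data[i, positions[i, j], ...] (hypothetical helper)."""
    batch_size = positions.shape[0]
    # batch_idx starts as (batch_size, 1) and is broadcast to positions.shape.
    batch_idx = np.arange(batch_size)[:, None] + np.zeros_like(positions)
    indices = np.stack([batch_idx, positions])  # (2, batch_size, num_sel)
    # Indexing with a tuple of index arrays plays the role of gather_nd here.
    return data[tuple(indices)]


def update_by_position(data, positions, update_val):
    """Return a copy where out[i, positions[i, j], ...] = update_val[i, j, ...]."""
    out = data.copy()  # out-of-place: the input is left untouched
    batch_idx = np.arange(positions.shape[0])[:, None]
    out[batch_idx, positions] = update_val
    return out


rng = np.random.default_rng(0)
data = rng.normal(0, 1, (5, 5, 5, 5))
positions = rng.integers(0, 5, (5, 4))
update_val = rng.normal(0, 1, (5, 4, 5, 5))

selected = select_by_position(data, positions)
updated = update_by_position(data, positions, update_val)
print(selected.shape)  # (5, 4, 5, 5)
```

Note that when a row of `positions` contains repeated values, the assignment in `update_by_position` keeps only one of the competing updates, so an operator covering this use case would need to define that behavior explicitly.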
We need the functionality to calculate b = index_add(a, indices, value), which mimics the outcome of a[indices] += value.

This is similar to tensor_scatter_nd_add in TF: https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_add

Also in JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update
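
To make the requested semantics concrete, below is a minimal NumPy sketch of out-of-place `index_add` and `index_update`. These are illustrative stand-ins, not the MXNet operators proposed here; the function bodies and the use of `np.add.at` are assumptions. One detail any implementation has to settle is how repeated indices behave: in NumPy, `a[indices] += value` applies only one update per repeated index, whereas `np.add.at` accumulates all of them.

```python
import numpy as np


def index_add(a, indices, value):
    """Return a copy of `a` with `value` added at `indices` (sketch only)."""
    out = a.copy()
    # np.add.at is unbuffered, so repeated indices accumulate.
    np.add.at(out, indices, value)
    return out


def index_update(a, indices, value):
    """Return a copy of `a` with `value` written at `indices` (sketch only)."""
    out = a.copy()
    out[indices] = value
    return out


a = np.zeros(4)
indices = np.array([0, 0, 2])
value = np.array([1.0, 2.0, 3.0])

b = index_add(a, indices, value)
print(b)  # [3. 0. 3. 0.]

# For comparison: in-place fancy-index addition is buffered, so only one
# of the two updates to index 0 survives.
c = a.copy()
c[indices] += value
print(c)
```

Both helpers leave `a` unchanged and return a new array, which is the functional form `b = index_add(a, indices, value)` asked for above.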