This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Operator] Add index_add or index_update to numpy extension #17823

Closed
sxjscience opened this issue Mar 12, 2020 · 3 comments

Comments

@sxjscience
Member

We need the functionality to calculate b = index_add(a, indices, value), which mimics the outcome of a[indices] += value.

This is similar to the tensor_scatter_nd_add in TF: https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_add

Also in JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update
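For reference, the requested out-of-place semantics can be sketched in plain NumPy (the name `index_add` is illustrative, mirroring the proposal; `np.add.at` gives unbuffered accumulation, which matters when `indices` contains duplicates):

```python
import numpy as np

def index_add(a, indices, value):
    """Out-of-place equivalent of a[indices] += value.

    Unlike out[indices] += value, np.add.at accumulates correctly
    when indices contains duplicates; the input a is left untouched.
    """
    out = a.copy()
    np.add.at(out, indices, value)
    return out

a = np.zeros(5)
b = index_add(a, np.array([0, 1, 1]), 1.0)
print(b)  # duplicated index 1 accumulates: [1. 2. 0. 0. 0.]
```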

@zheyuye
Contributor

zheyuye commented Apr 14, 2020

Besides that, it would be great if npx.scatter_nd were implemented for mx.np.ndarray.
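For readers unfamiliar with the op, the semantics being requested can be sketched in plain NumPy (argument order and names here are illustrative, not the MXNet signature):

```python
import numpy as np

def scatter_nd(indices, updates, shape):
    """NumPy sketch of scatter_nd-style semantics: scatter updates
    into a zero tensor of the given shape, where column m of indices
    gives the coordinates of updates[m]:
        out[indices[0, m], indices[1, m], ...] = updates[m]
    """
    out = np.zeros(shape, dtype=updates.dtype)
    out[tuple(indices)] = updates
    return out

out = scatter_nd(np.array([[0, 1], [2, 0]]), np.array([1.0, 2.0]), (2, 3))
print(out)  # [[0. 0. 1.]
            #  [2. 0. 0.]]
```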

@sxjscience
Member Author

@zheyuye @JiangZhaoh @yzhliu @haojin2

To understand the problem, let's consider two use cases. The first can be solved via gather_nd; the second cannot be solved with the existing MXNet operators.

Take elements at specific locations from the input data

out[i, j, ...] = data[i, positions[i, j], ...]

In GluonNLP, positions holds the masked locations in the input for which we need to calculate the loss; data is the mapped hidden states of the sequences.

With advanced indexing + imperative API, we can do something like this:

import mxnet as mx
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
out = data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions]
print(out.asnumpy().shape)

In order to make the network hybridizable, we can implement it via gather_nd:

import numpy as np
from mxnet.util import use_np


@use_np
def select_vectors_by_position(F, data, positions):
    """Select each batch with the given positions.
    Once advanced indexing can be hybridized, we can revise the implementation.
    out[i, j, :] = data[i, positions[i, j], :]
    Parameters
    ----------
    F
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, units)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.
    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, units)
    """
    # Here, we use gather_nd to select the output from data:
    # Need to compute
    #   out[i, j, :] = in[i, masked_position[i, j], :]
    # Thus, construct an indices tensor with shape [2, batch_size, num_masked_position], where
    #     indices[0, i, j] = i
    #     indices[1, i, j] = masked_position[i, j]
    # Then, out = gather_nd(in, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0),
                                 axis=1).astype(np.int32)
    batch_idx = batch_idx + F.np.zeros_like(positions)
    indices = F.np.stack([batch_idx, positions])
    out = F.npx.gather_nd(data, indices)
    return out
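The gather pattern above can be cross-checked in plain NumPy (a standalone sketch with made-up shapes, independent of the MXNet API):

```python
import numpy as np

batch_size, seq_length, units, num_sel = 2, 6, 3, 4
data = np.random.normal(size=(batch_size, seq_length, units))
positions = np.random.randint(0, seq_length, size=(batch_size, num_sel))

# Same construction as in select_vectors_by_position:
#   batch_idx[i, j] = i,  so the pair (batch_idx, positions) addresses
#   data[i, positions[i, j], :]
batch_idx = np.arange(batch_size)[:, None] + np.zeros_like(positions)
out = data[batch_idx, positions]  # shape (batch_size, num_sel, units)

# Verify against the definition out[i, j, :] = data[i, positions[i, j], :]
for i in range(batch_size):
    for j in range(num_sel):
        assert np.allclose(out[i, j], data[i, positions[i, j]])
print(out.shape)  # (2, 4, 3)
```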

Update elements at specific locations of the input data

For example, if we have some selected locations and need to replace the elements there with our own generated elements, i.e.,

data[i, positions[i, j], ...] = update_val[i, j, ...]

With advanced indexing + imperative API, we can do something like this:

import mxnet as mx
import numpy.testing as npt
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
print(data.asnumpy().shape)

# or do
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] += update_val
print(data.asnumpy().shape)

However, we cannot do this inside an autograd recording scope:

import mxnet as mx
import numpy.testing as npt
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data.attach_grad()
with mx.autograd.record():
    data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
mx.npx.waitall()

Error message:

MXNetError: Traceback (most recent call last):
  File "src/imperative/imperative.cc", line 203
MXNetError: Check failed: AGInfo::IsNone(*output): Assigning to NDArrays that are already in a computational graph will cause undefined behavior when evaluating gradients. Please call backward first to clear the graph or do this out side of a record section. Also note that you cannot use inplace operations like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section.

We will need to have a workaround solution to this use case.
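One out-of-place workaround, built only from broadcasting ops so that every step stays differentiable (a NumPy sketch; the same recipe could be expressed with hybridizable MXNet ops; `index_update` is an illustrative name and the sketch assumes positions[i, :] has no duplicates):

```python
import numpy as np

def index_update(data, positions, update_val):
    """Out-of-place sketch of data[i, positions[i, j], ...] = update_val[i, j, ...]
    using only broadcasting, masking, and a contraction (no in-place writes).
    Assumes positions[i, :] contains no duplicate indices.
    """
    batch_size, seq_length = data.shape[:2]
    # one_hot[i, j, k] = 1 iff positions[i, j] == k
    one_hot = (positions[..., None] == np.arange(seq_length)).astype(data.dtype)
    # mask[i, k] = 1 iff position k of sample i is being updated
    mask = one_hot.sum(axis=1)
    trailing = (1,) * (data.ndim - 2)
    kept = data * (1 - mask).reshape(mask.shape + trailing)
    # scatter: sum over j of one_hot[i, j, k] * update_val[i, j, ...]
    scattered = np.einsum('ijk,ij...->ik...', one_hot, update_val)
    return kept + scattered

data = np.random.normal(size=(5, 5, 5, 5))
positions = np.stack([np.random.permutation(5)[:4] for _ in range(5)])
update_val = np.random.normal(size=(5, 4, 5, 5))
out = index_update(data, positions, update_val)

# Cross-check against the in-place assignment semantics
expected = data.copy()
expected[np.arange(5)[:, None], positions] = update_val
assert np.allclose(out, expected)
```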

@zheyuye zheyuye mentioned this issue May 14, 2020
@zheyuye
Contributor

zheyuye commented May 25, 2020

The first use case would be improved through #18319, which was inspired by #17327.
