
Need to use safe accumulation for calculating the gradient of Embedding + Take #17703

Open
sxjscience opened this issue Feb 27, 2020 · 3 comments

Comments

@sxjscience
Member

Description

Currently, the inner gradient accumulation in Embedding and take does not use safe accumulation, which means we lose precision in the fp16 case. Here is an example that amplifies the issue:

import mxnet as mx
import numpy as np
mx.npx.set_np()

ctx = mx.gpu()
vocab_size = 8
embedding_dim = 1
index_num = 100000  # far more indices than vocabulary rows, so each row's gradient is accumulated many times

dat = mx.np.random.randint(0, vocab_size, size=(index_num,), ctx=ctx)

for dtype in [np.float16, np.float32]:
    weight = mx.np.random.normal(0, 1, size=(vocab_size, embedding_dim), ctx=ctx, dtype=dtype)

    weight.attach_grad(grad_req='add')
    weight.grad[:] = 1.0
    with mx.autograd.record():
        out = mx.npx.embedding(dat, weight, input_dim=vocab_size, output_dim=embedding_dim) * 0.01
        out.backward()
    print('dtype=', dtype)
    print(weight.grad)

Output:

dtype= <class 'numpy.float16'>
[[32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]] @gpu(0)
dtype= <class 'numpy.float32'>
[[126.748665]
 [127.53883 ]
 [125.30836 ]
 [125.36837 ]
 [126.278564]
 [127.05873 ]
 [124.74824 ]
 [125.018295]] @gpu(0)
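For reference, each of the 8 rows is selected roughly 100000 / 8 = 12500 times, so the expected gradient is about 1.0 (the initial value) + 12500 * 0.01 ≈ 126, which matches the float32 output. In float16 the running sum gets stuck at 32: the spacing between adjacent float16 values at 32 is 0.03125, so adding 0.01 rounds back to 32. A quick check:

import numpy as np

# Each float16 addition rounds to the nearest representable value;
# at 32 the spacing is 0.03125, so a 0.01 increment is lost entirely.
acc = np.float16(32.0)
print(acc + np.float16(0.01))  # 32.0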

The same happens for take:

import mxnet as mx
import numpy as np
mx.npx.set_np()

ctx = mx.gpu()
vocab_size = 8
embedding_dim = 1
index_num = 100000

dat = mx.np.random.randint(0, vocab_size, size=(index_num,), ctx=ctx)
for dtype in [np.float16, np.float32]:
    weight = mx.np.random.normal(0, 1, size=(vocab_size, embedding_dim), ctx=ctx, dtype=dtype)

    weight.attach_grad(grad_req='add')
    weight.grad[:] = 1.0
    with mx.autograd.record():
        out = mx.np.take(weight, dat, axis=0) * 0.01
        out.backward()
    print('dtype=', dtype)
    print(weight.grad)

Output:

dtype= <class 'numpy.float16'>
[[32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]
 [32.]] @gpu(0)
dtype= <class 'numpy.float32'>
[[125.44839 ]
 [126.68865 ]
 [126.62864 ]
 [125.44839 ]
 [127.028725]
 [126.38859 ]
 [125.108315]
 [125.32836 ]] @gpu(0)
sxjscience added the Bug label Feb 27, 2020
@szha
Member

szha commented Feb 27, 2020

Is stable sum enabled? It matters when there is a large number of duplicates in the lookup index.

@sxjscience
Member Author

To clarify, it affects both nd and the new numpy (I'm testing with the numpy interface because I'd like to test the take operator).

This mimics the setting that appears in large-batch training of transformers, in which there are far more tokens than vocabulary entries.

@sxjscience
Member Author

The simplest fix is to revise the kernels to use safe accumulation, i.e., cast float16 to float32 before accumulating. I also suggest turning on MXNET_SAFE_ACCUMULATION for the float16 type by default in 1.7 (changing the current default behavior) so that float16 values are accumulated in float32.
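For operators whose kernels already check the flag, safe accumulation can be requested today through the MXNET_SAFE_ACCUMULATION environment variable (the Embedding/take gradients discussed here do not honor it yet). For example, assuming the flag is picked up from the process environment:

import os

# Set the flag before importing mxnet so every operator that checks it sees it.
os.environ['MXNET_SAFE_ACCUMULATION'] = '1'

import mxnet as mx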

I think we should use the following approach for summing up a sequence of float16 numbers (see the NumPy sketch after the list):

  1. load them as half2,
  2. cast each half2 to float2,
  3. accumulate in float32.
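A minimal NumPy sketch of the widen-then-accumulate idea (illustrative only; it ignores the vectorized half2 loads and is not the actual CUDA kernel): casting each fp16 value to float32 before adding keeps the running sum accurate, while a pure fp16 accumulator gets stuck at 32.

import numpy as np

vals = np.full(12500, 0.01, dtype=np.float16)  # one row's worth of 0.01 contributions

# Naive fp16 accumulation: the running sum saturates once it reaches 32,
# because 32 + 0.01 rounds back to 32 in float16.
naive = np.float16(0.0)
for v in vals:
    naive = naive + v

# Safe accumulation: widen each value to float32 before adding,
# then cast the final sum back to float16 for storage.
safe = np.float32(0.0)
for v in vals:
    safe = safe + np.float32(v)

print(naive)             # 32.0
print(np.float16(safe))  # ~125, i.e. 12500 * 0.01, matching the float32 gradients above (minus the initial 1.0)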
