This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Safe accumulation for computing gradient in Embedding & Take (#18385)
* Safe accumulation for computing gradient in Embedding & Take

* Fix bug in TakeGrad: initialize temporal storage for safe_accumulation

* fix lint

* make MXNET_SAFE_ACCUMULATION compatible with Windows

* Increase test coverage: small inputs & SAFE_ACCUMULATION
MoisesHer committed Aug 14, 2020
1 parent a2b400c commit 344587f
Showing 7 changed files with 453 additions and 117 deletions.
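
The pattern introduced throughout this commit is the same in every file: when MXNET_SAFE_ACCUMULATION is enabled, gradients are accumulated in a wider type (AType, e.g. float32 for float16 data) held in temporary storage, and cast back to the storage type only at the end. A minimal host-side sketch of why that matters, using float vs. double as a stand-in for float16 vs. float32 (portable C++ has no half type); the constants below are illustrative and not taken from the MXNet sources:

#include <cstdio>

int main() {
  // Accumulate one hundred million small gradient contributions.
  const long long n = 100000000LL;
  const float contribution = 0.1f;

  float  narrow_sum = 0.0f;   // accumulate in the storage type
  double wide_sum   = 0.0;    // "safe accumulation": accumulate in a wider type

  for (long long i = 0; i < n; ++i) {
    narrow_sum += contribution;                      // stagnates once the sum gets large
    wide_sum   += static_cast<double>(contribution); // stays accurate, cast back at the end
  }

  std::printf("narrow: %.1f\n", narrow_sum);                    // ~2097152.0, far below the exact 1e7
  std::printf("wide:   %.1f\n", static_cast<float>(wide_sum));  // ~10000000.0
  return 0;
}
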
77 changes: 77 additions & 0 deletions 3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
@@ -641,6 +641,43 @@ __global__ void AddTakeGradKernel(DstPlan dst,
}
}

template<bool clip, int x_bits, typename DstPlan, typename ATypePlan,
typename SrcPlan1, typename SrcPlan2>
__global__ void AddTakeGradKernel(DstPlan dst,
ATypePlan temp,
SrcPlan1 index, SrcPlan2 src,
index_t ymax, index_t xmax, const int K) {
const unsigned x_size = 1 << x_bits;
const int xindex = blockIdx.x * x_size + threadIdx.x;
__shared__ int ptr;
if (xindex < xmax) {
for (unsigned y = 0; y < K; ++y) {
temp.REval(y, xindex) = dst.Eval(y, xindex);
}
}
for (unsigned y = 0; y < ymax; ++y) {
if (threadIdx.x == 0) {
ptr = index.Eval(0, y);
if (clip) {
if (ptr <= 0) ptr = 0;
else if (ptr >= K) ptr = K - 1;
} else {
ptr %= K;
if (ptr < 0) ptr += K;
}
}
__syncthreads();
if (xindex < xmax) {
temp.REval(ptr, xindex) += src.Eval(y, xindex);
}
}
if (xindex < xmax) {
for (unsigned y = 0; y < K; ++y) {
dst.REval(y, xindex) = temp.Eval(y, xindex);
}
}
}
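
The new kernel resolves out-of-range row indices the same way as the existing one: with clip=true they are clamped into [0, K-1], otherwise they wrap modulo K. A small host-side sketch of that resolution logic; the function name is ours, not part of mshadow:

#include <cassert>

// Map a possibly out-of-range row index onto [0, K-1], mirroring the
// clip / wrap branches of AddTakeGradKernel above.
int resolve_row(int ptr, int K, bool clip) {
  if (clip) {
    if (ptr <= 0) return 0;       // negative indices clamp to the first row
    if (ptr >= K) return K - 1;   // too-large indices clamp to the last row
    return ptr;
  }
  ptr %= K;                       // wrap-around mode
  if (ptr < 0) ptr += K;          // keep the result non-negative
  return ptr;
}

int main() {
  const int K = 10;
  assert(resolve_row(-3, K, true)  == 0);
  assert(resolve_row(12, K, true)  == 9);
  assert(resolve_row(-3, K, false) == 7);   // -3 mod 10 -> 7
  assert(resolve_row(12, K, false) == 2);
  return 0;
}
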

template<int warp_bits, int SZ, typename DType, typename IdxType>
__global__ void AddTakeGradLargeBatchKernel(DType* dst,
const IdxType *sorted, const IdxType *index,
@@ -733,6 +770,46 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}

template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
Tensor<gpu, 2, AType> temp,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
CHECK_EQ(src.CheckContiguous(), true);
const int kUnitBits = kMemUnitBits + 1;
dim3 dimBlock(1 << kUnitBits);
dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits);

CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch";
CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch";
CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad");
cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
const int K = dst.shape_[0];

if (clip) {
AddTakeGradKernel<true, kUnitBits>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(temp),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
} else {
AddTakeGradKernel<false, kUnitBits>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(temp),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
}
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
26 changes: 26 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor.h
@@ -848,6 +848,19 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
* \param index index to take
* \param src source output
*/
template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
Tensor<cpu, 2, AType> temp,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
* \brief CPU/GPU: Gradient accumulation for the embedding matrix with safe accumulation.
dst[index[i]] += src[i]
* \param dst destination
* \param temp temporary storage for safe accumulation
* \param index index to take
* \param src source output
*/
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
@@ -861,6 +874,19 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
* \param index original index of the sorted indices
* \param src source output
*/
template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
Tensor<gpu, 2, AType> temp,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
/*!
* \brief CPU/GPU: Gradient accumulation for the embedding matrix with safe accumulation.
dst[index[i]] += src[i]
* \param dst destination
* \param temp temporary storage for safe accumulation
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
33 changes: 33 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor_cpu-inl.h
@@ -539,6 +539,33 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
}
}

// safe accumulation
template<bool clip, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
Tensor<cpu, 2, AType> temp,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
const index_t K = dst.shape_[0];
const index_t C = dst.shape_[1];
for (index_t j = 0; j < K; ++j) {
for (index_t i = 0; i < C; ++i) {
temp[j][i] = dst[j][i];
}
}
for (index_t y = 0; y < index.size(0); ++y) {
index_t j = index[y];
if (clip) {
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
} else {
j %= K;
if (j < 0) j += K;
}
for (index_t i = 0; i < C; ++i) {
temp[j][i] += src[y][i];
}
}
for (index_t j = 0; j < K; ++j) {
for (index_t i = 0; i < C; ++i) {
dst[j][i] = temp[j][i];
}
}
}
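
The same three phases — copy dst into the wider temp buffer, accumulate the src rows selected by index, cast back into dst — in a self-contained sketch using std::vector so it can be compiled outside mshadow. Here DType is float and AType is double purely for illustration; in MXNet the interesting case is float16 storage with a float32 accumulator:

#include <cstddef>
#include <vector>

// dst: K x C (row-major), src: N x C, index: N entries assumed already in [0, K)
// (clip/wrap handling as in the code above). temp is the wider K x C accumulation buffer.
void add_take_grad_safe(std::vector<float>* dst,
                        std::vector<double>* temp,
                        const std::vector<int>& index,
                        const std::vector<float>& src,
                        std::size_t K, std::size_t C) {
  // 1) load the existing gradient into the wide buffer
  for (std::size_t i = 0; i < K * C; ++i) (*temp)[i] = (*dst)[i];
  // 2) accumulate every source row into the destination row selected by its index
  for (std::size_t y = 0; y < index.size(); ++y) {
    const std::size_t j = static_cast<std::size_t>(index[y]);
    for (std::size_t i = 0; i < C; ++i)
      (*temp)[j * C + i] += static_cast<double>(src[y * C + i]);
  }
  // 3) cast back to the storage type
  for (std::size_t i = 0; i < K * C; ++i) (*dst)[i] = static_cast<float>((*temp)[i]);
}
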

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
8 changes: 8 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor_gpu-inl.h
@@ -239,6 +239,14 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}

template<bool clip, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
Tensor<gpu, 2, AType> temp,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGrad<clip, IndexType, DType>(dst, temp, index, src);
}

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
84 changes: 52 additions & 32 deletions src/operator/tensor/indexing_op.cu
@@ -684,27 +684,25 @@ __global__ void EmbeddingFindBounds(const IType *sorted_data,
* \param grad_out output gradient data
* \param embbedding_dim dimension of the dense embedding
* \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
* \param nelems_per_load number of elements per load, given by sizeof(LType) / sizeof(DType)
* \param req write/add/null
*/
template <typename AType, typename LType, typename DType, typename IType>
__global__ void EmbeddingGradKernel(DType *grad_in,
const IType *original_index,
const IType *index_bounds,
const DType *grad_out,
const index_t embbedding_dim,
const index_t vocab_dim,
const int nelems_per_load,
const int req) {
extern __shared__ int sharedmem[];
AType* grad_in_row = reinterpret_cast<AType *>(sharedmem);
const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
const index_t aligned_emb_dim = embbedding_dim / nelems_per_load;
LType load_value[1];
DType* data_values = reinterpret_cast<DType*>(load_value);

IType my_row = blockIdx.x;
if (my_row < vocab_dim) {
@@ -716,29 +714,37 @@ __global__ void EmbeddingGradKernel(DType *grad_in,
for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
// Initialize grad_in
if (req == kAddTo) {
*load_value = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast<AType>(data_values[val_id]);
}
} else {
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast<AType>(0.0);
}
}
// Add all rows from grad_out according to indices in data
for (index_t data_idx=lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
*load_value = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
grad_in_row[val_id * blockDim.x + threadIdx.x] += static_cast<AType>(data_values[val_id]);
}
}
// Save results
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
data_values[val_id] = static_cast<DType>(grad_in_row[val_id * blockDim.x + threadIdx.x]);
}
aligned_grad_in[my_row * aligned_emb_dim + emb_id] = *load_value;
}
}
}
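
The kernel now stages per-thread partial sums in shared memory as AType, laid out as grad_in_row[val_id * blockDim.x + threadIdx.x], while data still moves through a single LType load whose bytes are reinterpreted as nelems_per_load DType values. A stripped-down CUDA sketch of that pattern — float4 loads accumulated into a double shared buffer and cast back on store — with all names and sizes chosen for illustration only:

#include <cstdio>
#include <cuda_runtime.h>

// Sum `rows` rows of a row-major [rows x cols] float matrix into `out` (cols floats),
// loading float4 at a time and accumulating in double shared memory with the same
// strided layout used by the patched EmbeddingGradKernel.
__global__ void ColumnSumWidened(const float* in, float* out, int rows, int cols) {
  extern __shared__ double acc[];                    // blockDim.x * 4 doubles
  const int vec_cols = cols / 4;                     // cols assumed divisible by 4
  const float4* in4 = reinterpret_cast<const float4*>(in);
  float4* out4 = reinterpret_cast<float4*>(out);
  for (int c = threadIdx.x; c < vec_cols; c += blockDim.x) {
    for (int v = 0; v < 4; ++v) acc[v * blockDim.x + threadIdx.x] = 0.0;
    for (int r = 0; r < rows; ++r) {
      const float4 x = in4[r * vec_cols + c];        // one vectorized load, 4 values
      acc[0 * blockDim.x + threadIdx.x] += x.x;
      acc[1 * blockDim.x + threadIdx.x] += x.y;
      acc[2 * blockDim.x + threadIdx.x] += x.z;
      acc[3 * blockDim.x + threadIdx.x] += x.w;
    }
    float4 y;                                        // cast back to the storage type
    y.x = static_cast<float>(acc[0 * blockDim.x + threadIdx.x]);
    y.y = static_cast<float>(acc[1 * blockDim.x + threadIdx.x]);
    y.z = static_cast<float>(acc[2 * blockDim.x + threadIdx.x]);
    y.w = static_cast<float>(acc[3 * blockDim.x + threadIdx.x]);
    out4[c] = y;                                     // one vectorized store
  }
}

int main() {
  const int rows = 4096, cols = 64, threads = 32;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, rows * cols * sizeof(float));
  cudaMallocManaged(&out, cols * sizeof(float));
  for (int i = 0; i < rows * cols; ++i) in[i] = 0.1f;
  const size_t shmem = threads * 4 * sizeof(double);
  ColumnSumWidened<<<1, threads, shmem>>>(in, out, rows, cols);
  cudaDeviceSynchronize();
  std::printf("out[0] = %f (expected ~%f)\n", out[0], rows * 0.1);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
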

template<typename AType, typename IType, typename DType>
void EmbeddingGradKernelCaller(const OpContext& ctx,
mshadow::Tensor<gpu, 2, DType> grad_in,
const mshadow::Tensor<gpu, 1, IType>& index,
const mshadow::Tensor<gpu, 2, DType> &grad_out,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow::expr;

@@ -792,20 +798,23 @@ void EmbeddingGradKernelCaller(const OpContext& ctx,

// Compute Gradient
int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType));

MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
CHECK_LE(sizeof(DType), sizeof(LType));
int nelems_per_load = sizeof(LType) / sizeof(DType);
int threads_block_grad = 32;
int maxThreads = 1024;
while (threads_block_grad < (embbedding_dim/nelems_per_load) &&
(threads_block_grad < maxThreads))
threads_block_grad += 32;
size_t required_shared = threads_block_grad * nelems_per_load * sizeof(AType);
dim3 blocks(vocab_dim, 1);
EmbeddingGradKernel<AType, LType><<<blocks, threads_block_grad, required_shared,
Stream<gpu>::GetStream(s)>>>(
grad_in.dptr_, original_index.dptr_,
bounds_index.dptr_, grad_out.dptr_,
embbedding_dim, vocab_dim,
nelems_per_load,
req[embedding::kWeight]);
});
}
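
The launcher rounds the block size up in warp-sized steps until one block covers the vectorized embedding row (capped at 1024 threads) and sizes dynamic shared memory as threads * nelems_per_load * sizeof(AType). A free-standing restatement of that sizing rule; the example numbers are ours, not anything computed by get_load_type:

#include <cstddef>
#include <cstdio>

// Pick a block size (multiple of 32, at most 1024) covering emb_dim / nelems_per_load
// lanes, plus the matching dynamic shared-memory size for an AType accumulator row.
void grad_launch_config(int emb_dim, int nelems_per_load, std::size_t atype_size,
                        int* threads, std::size_t* shared_bytes) {
  int t = 32;
  while (t < emb_dim / nelems_per_load && t < 1024) t += 32;
  *threads = t;
  *shared_bytes = static_cast<std::size_t>(t) * nelems_per_load * atype_size;
}

int main() {
  int threads = 0;
  std::size_t shared = 0;
  // e.g. suppose a 300-wide embedding is loaded 4 elements at a time and accumulated in float32
  grad_launch_config(300, 4, sizeof(float), &threads, &shared);
  std::printf("threads = %d, shared bytes = %zu\n", threads, shared);  // 96, 1536
  return 0;
}
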
@@ -831,9 +840,17 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& oshape = inputs[0].shape_;

Stream<gpu> *s = ctx.get_stream<gpu>();

CHECK_NE(req[embedding::kWeight], kWriteInplace)
<< "Backward of Embedding does not support writing in place.";
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward "
"with float16 inputs. "
"See https://mxnet.apache.org/api/faq/env_var "
"for more details.");
}
MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, {
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
Tensor < gpu, 1, IType > data = inputs[1].get_with_shape<gpu, 1, IType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Expand All @@ -842,7 +859,10 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);

if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
if (safe_acc)
EmbeddingGradKernelCaller<AType>(ctx, grad_in, data, grad_out, req);
else
EmbeddingGradKernelCaller<DType>(ctx, grad_in, data, grad_out, req);
} else {
LOG(FATAL) << "wrong req";
}
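
EmbeddingOpBackward now reads MXNET_SAFE_ACCUMULATION once and calls the same kernel caller with either the widened accumulation type or the storage type. A stand-alone sketch of that dispatch shape, with a trivial reduction standing in for the real gradient kernel; nothing below is MXNet API:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

// Accumulate in AType, return in the storage type (float stands in for DType here).
template <typename AType>
float sum_as(const std::vector<float>& grads) {
  AType sum = AType(0);
  for (float g : grads) sum += static_cast<AType>(g);
  return static_cast<float>(sum);
}

int main() {
  // Same flag the operator checks; in this sketch any value other than "0" or unset enables it.
  const char* env = std::getenv("MXNET_SAFE_ACCUMULATION");
  const bool safe_acc = env != nullptr && std::strcmp(env, "0") != 0;

  std::vector<float> grads(1 << 24, 0.25f);
  const float total = safe_acc ? sum_as<double>(grads)   // widened accumulator
                               : sum_as<float>(grads);   // storage-type accumulator
  std::printf("safe_acc=%d total=%f\n", static_cast<int>(safe_acc), total);
  return 0;
}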