From d9d400f98636f7e9a290a4994524d23f965712fb Mon Sep 17 00:00:00 2001
From: MoisesHer <50716238+MoisesHer@users.noreply.github.com>
Date: Thu, 13 Aug 2020 22:18:26 -0700
Subject: [PATCH] Safe accumulation for computing gradient in Embedding & Take
 (#18385)

* Safe accumulation for computing gradient in Embedding & Take

* Fix bug in TakeGrad: initialize temporary storage for safe_accumulation

* fix lint

* make MXNET_SAFE_ACCUMULATION compatible with Windows

* Increase test coverage: small inputs & SAFE_ACCUMULATION
---
 .../mshadow/mshadow/cuda/tensor_gpu-inl.cuh |  77 ++++++++
 3rdparty/mshadow/mshadow/tensor.h           |  26 +++
 3rdparty/mshadow/mshadow/tensor_cpu-inl.h   |  33 ++++
 3rdparty/mshadow/mshadow/tensor_gpu-inl.h   |   8 +
 src/operator/tensor/indexing_op.cu          |  84 +++++---
 src/operator/tensor/indexing_op.h           | 186 +++++++++++++++---
 tests/python/gpu/test_operator_gpu.py       | 156 ++++++++++-----
 7 files changed, 453 insertions(+), 117 deletions(-)

diff --git a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
index 02a74b2ad46f..a00aade96835 100644
--- a/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
+++ b/3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
@@ -641,6 +641,43 @@ __global__ void AddTakeGradKernel(DstPlan dst,
   }
 }
 
+template
+__global__ void AddTakeGradKernel(DstPlan dst,
+                                  ATypePlan temp,
+                                  SrcPlan1 index, SrcPlan2 src,
+                                  index_t ymax, index_t xmax, const int K) {
+  const unsigned x_size = 1 << x_bits;
+  const int xindex = blockIdx.x * x_size + threadIdx.x;
+  __shared__ int ptr;
+  if (xindex < xmax) {
+    for (unsigned y = 0; y < K; ++y) {
+      temp.REval(y, xindex) = dst.Eval(y, xindex);
+    }
+  }
+  for (unsigned y = 0; y < ymax; ++y) {
+    if (threadIdx.x == 0) {
+      ptr = index.Eval(0, y);
+      if (clip) {
+        if (ptr <= 0) ptr = 0;
+        else if (ptr >= K) ptr = K - 1;
+      } else {
+        ptr %= K;
+        if (ptr < 0) ptr += K;
+      }
+    }
+    __syncthreads();
+    if (xindex < xmax) {
+      temp.REval(ptr, xindex) += src.Eval(y, xindex);
+    }
+  }
+  if (xindex < xmax) {
+    for (unsigned y = 0; y < K; ++y) {
+      dst.REval(y, xindex) = temp.Eval(y, xindex);
+    }
+  }
+}
+
 template
 __global__ void AddTakeGradLargeBatchKernel(DType* dst,
                                             const IdxType *sorted, const IdxType *index,
@@ -733,6 +770,46 @@ inline void AddTakeGrad(Tensor dst,
   MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
 }
 
+template
+inline void AddTakeGrad(Tensor dst,
+                        Tensor temp,
+                        const Tensor& index,
+                        const Tensor &src) {
+  CHECK_EQ(dst.CheckContiguous(), true);
+  CHECK_EQ(index.CheckContiguous(), true);
+  CHECK_EQ(src.CheckContiguous(), true);
+  const int kUnitBits = kMemUnitBits + 1;
+  dim3 dimBlock(1 << kUnitBits);
+  dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits);
+
+  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch";
+  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch";
+  CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad");
+  cudaStream_t stream = Stream::GetStream(dst.stream_);
+  const int K = dst.shape_[0];
+
+  if (clip) {
+    AddTakeGradKernel
+        <<>>
+        (expr::MakePlan(dst),
+         expr::MakePlan(temp),
+         expr::MakePlan(index),
+         expr::MakePlan(src),
+         src.size(0),
+         src.size(1), K);
+  } else {
+    AddTakeGradKernel
+        <<>>
+        (expr::MakePlan(dst),
+         expr::MakePlan(temp),
+         expr::MakePlan(index),
+         expr::MakePlan(src),
+         src.size(0),
+         src.size(1), K);
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
+}
+
 template
 inline void AddTakeGradLargeBatch(Tensor dst,
                                   const Tensor& sorted,
diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h
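// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the safe-accumulation pattern the
// kernels above implement. Gradients stored in a narrow type DType (e.g. half
// precision) are first widened into a temporary buffer of a wider type AType
// (e.g. float), every addition happens in AType, and the result is narrowed back
// to DType once at the end. This avoids the rounding/underflow error of repeated
// low-precision "+=" when many output-gradient rows map to the same weight row.
// The function and type names below are illustrative only.
#include <cstddef>
#include <vector>

template <typename DType, typename AType, typename IType>
void add_take_grad_rows(std::vector<DType>* dst,          // [K x C] weight gradient, row-major
                        const std::vector<IType>& index,  // [N] row ids into dst
                        const std::vector<DType>& src,    // [N x C] output gradient, row-major
                        std::size_t K, std::size_t C) {
  // 1) Widen the existing destination values into an AType scratch buffer.
  std::vector<AType> acc(dst->size());
  for (std::size_t i = 0; i < dst->size(); ++i)
    acc[i] = static_cast<AType>((*dst)[i]);
  // 2) Accumulate every source row into its target row in the wide type.
  for (std::size_t y = 0; y < index.size(); ++y) {
    std::ptrdiff_t j = static_cast<std::ptrdiff_t>(index[y]) % static_cast<std::ptrdiff_t>(K);
    if (j < 0) j += static_cast<std::ptrdiff_t>(K);  // "wrap" index mode
    for (std::size_t i = 0; i < C; ++i)
      acc[static_cast<std::size_t>(j) * C + i] += static_cast<AType>(src[y * C + i]);
  }
  // 3) Narrow back to DType with a single rounding step per element.
  for (std::size_t i = 0; i < dst->size(); ++i)
    (*dst)[i] = static_cast<DType>(acc[i]);
}
// With DType = float and AType = double, add_take_grad_rows mirrors the layout and
// wrap handling of the CPU AddTakeGrad overload added elsewhere in this patch, but
// it is a simplified, unoptimized illustration rather than the patch's own code.
// ---------------------------------------------------------------------------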
index 8dd57e20cd67..c92bf8d076d5 100644 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -848,6 +848,19 @@ inline void AddTakeGrad(Tensor dst, * \param index index to take * \param src source output */ +template +inline void AddTakeGrad(Tensor dst, + Tensor temp, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Gradient accumulate of embedding matrix with safe accumulation. + dst[index[i]] += src[i] + * \param dst destination + * \temp temporal storage for safe accumulation + * \param index index to take + * \param src source output + */ template inline void AddTakeGrad(Tensor dst, const Tensor& index, @@ -861,6 +874,19 @@ inline void AddTakeGrad(Tensor dst, * \param index original index of the sorted indices * \param src source output */ +template +inline void AddTakeGrad(Tensor dst, + Tensor temp, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Gradient accumulate of embedding matrix with safe accumulation. + dst[index[i]] += src[i] + * \param dst destination + * \temp temporal storage for safe accumulation + * \param index index to take + * \param src source output + */ template inline void AddTakeGradLargeBatch(Tensor dst, const Tensor& sorted, diff --git a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h index 2d00220a1142..5f05f0a20225 100644 --- a/3rdparty/mshadow/mshadow/tensor_cpu-inl.h +++ b/3rdparty/mshadow/mshadow/tensor_cpu-inl.h @@ -539,6 +539,39 @@ inline void AddTakeGrad(Tensor dst, } } +// safe accumulation +template +inline void AddTakeGrad(Tensor dst, + Tensor temp, + const Tensor& index, + const Tensor &src) { + const index_t K = dst.shape_[0]; + const index_t C = dst.shape_[1]; + for (index_t j = 0; j < K; ++j) { + for (index_t i = 0; i < C; ++i) { + temp[j][i] = dst[j][i]; + } + } + for (index_t y = 0; y < index.size(0); ++y) { + index_t j = index[y]; + if (clip) { + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + } else { + j %= K; + if (j < 0) j += K; + } + for (index_t i = 0; i < C; ++i) { + temp[j][i] += src[y][i]; + } + } + for (index_t j = 0; j < K; ++j) { + for (index_t i = 0; i < C; ++i) { + dst[j][i] = temp[j][i]; + } + } +} + template inline void AddTakeGradLargeBatch(Tensor dst, const Tensor& sorted, diff --git a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h index e7dde2776f43..3140259d52f0 100644 --- a/3rdparty/mshadow/mshadow/tensor_gpu-inl.h +++ b/3rdparty/mshadow/mshadow/tensor_gpu-inl.h @@ -239,6 +239,14 @@ inline void AddTakeGrad(Tensor dst, cuda::AddTakeGrad(dst, index, src); } +template +inline void AddTakeGrad(Tensor dst, + Tensor temp, + const Tensor& index, + const Tensor &src) { + cuda::AddTakeGrad(dst, temp, index, src); +} + template inline void AddTakeGradLargeBatch(Tensor dst, const Tensor& sorted, diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index e3c8a787d7ad..f9d7a1986bd5 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -684,27 +684,25 @@ __global__ void EmbeddingFindBounds(const IType *sorted_data, * \param grad_out output gradient data * \param embbedding_dim dimension of the dense embedding * \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size + * \param nelems_per_load number of elements per each load based on (LType / DType) * \param req write/add/null */ -template +template __global__ void EmbeddingGradKernel(DType *grad_in, - const IType *original_index, - 
const IType *index_bounds, - const DType *grad_out, - const index_t embbedding_dim, - const index_t vocab_dim, - const int req) { + const IType *original_index, + const IType *index_bounds, + const DType *grad_out, + const index_t embbedding_dim, + const index_t vocab_dim, + const int nelems_per_load, + const int req) { extern __shared__ int sharedmem[]; - LType* grad_in_row = reinterpret_cast(sharedmem); - - // LType has to be bigger than DType, guarded in the launcher code - const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1; + AType* grad_in_row = reinterpret_cast(sharedmem); const LType *aligned_grad_out = reinterpret_cast(grad_out); LType *aligned_grad_in = reinterpret_cast(grad_in); - const index_t aligned_emb_dim = embbedding_dim / n_val; - DType *my_grad_in_row = reinterpret_cast(&grad_in_row[threadIdx.x]); - LType Lvalue[1]; - DType* Dvalues = reinterpret_cast(Lvalue); + const index_t aligned_emb_dim = embbedding_dim / nelems_per_load; + LType load_value[1]; + DType* data_values = reinterpret_cast(load_value); IType my_row = blockIdx.x; if (my_row < vocab_dim) { @@ -716,29 +714,37 @@ __global__ void EmbeddingGradKernel(DType *grad_in, for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) { // Initialize grad_in if (req == kAddTo) { - grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id]; + *load_value = aligned_grad_in[my_row * aligned_emb_dim + emb_id]; + for (index_t val_id = 0; val_id < nelems_per_load; val_id++) { + grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast(data_values[val_id]); + } } else { - grad_in_row[threadIdx.x] = 0.0; + for (index_t val_id = 0; val_id < nelems_per_load; val_id++) { + grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast(0.0); + } } // Add all rows from grad_out according to indices in data for (index_t data_idx=lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) { - *Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id]; - for (index_t val_id = 0; val_id < n_val; val_id++) { - my_grad_in_row[val_id] += Dvalues[val_id]; + *load_value = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id]; + for (index_t val_id = 0; val_id < nelems_per_load; val_id++) { + grad_in_row[val_id * blockDim.x + threadIdx.x] += static_cast(data_values[val_id]); } } // Save results - aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x]; + for (index_t val_id = 0; val_id < nelems_per_load; val_id++) { + data_values[val_id] = static_cast(grad_in_row[val_id * blockDim.x + threadIdx.x]); + } + aligned_grad_in[my_row * aligned_emb_dim + emb_id] = *load_value; } } } -template +template void EmbeddingGradKernelCaller(const OpContext& ctx, - mshadow::Tensor grad_in, - const mshadow::Tensor& index, - const mshadow::Tensor &grad_out, - const std::vector& req) { + mshadow::Tensor grad_in, + const mshadow::Tensor& index, + const mshadow::Tensor &grad_out, + const std::vector& req) { using namespace mxnet_op; using namespace mshadow::expr; @@ -792,20 +798,23 @@ void EmbeddingGradKernelCaller(const OpContext& ctx, // Compute Gradient int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType)); + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { - int nelems_per_thread = sizeof(LType) / sizeof(DType); + CHECK_LE(sizeof(DType), sizeof(LType)); + int nelems_per_load = sizeof(LType) / sizeof(DType); int threads_block_grad = 32; int maxThreads = 1024; - while (threads_block_grad < 
(embbedding_dim/nelems_per_thread) && + while (threads_block_grad < (embbedding_dim/nelems_per_load) && (threads_block_grad < maxThreads)) threads_block_grad += 32; - size_t required_shared = threads_block_grad * sizeof(LType); + size_t required_shared = threads_block_grad * nelems_per_load * sizeof(AType); dim3 blocks(vocab_dim, 1); - EmbeddingGradKernel<<<<::GetStream(s)>>>( grad_in.dptr_, original_index.dptr_, bounds_index.dptr_, grad_out.dptr_, embbedding_dim, vocab_dim, + nelems_per_load, req[embedding::kWeight]); }); } @@ -831,9 +840,17 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& oshape = inputs[0].shape_; Stream *s = ctx.get_stream(); + CHECK_NE(req[embedding::kWeight], kWriteInplace) << "Backward of Embedding does not support writing in place."; - MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, { + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); + if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) { + common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward " + "with float16 inputs. " + "See https://mxnet.apache.org/api/faq/env_var " + "for more details."); + } + MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, { MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { Tensor < gpu, 1, IType > data = inputs[1].get_with_shape( Shape1(ishape.ProdShape(0, ishape.ndim())), s); @@ -842,7 +859,10 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, Tensor grad_in = outputs[1].get(s); if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) { - EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req); + if (safe_acc) + EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req); + else + EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req); } else { LOG(FATAL) << "wrong req"; } diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5454900968ec..7f0f0fa0ad4c 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -400,20 +400,38 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& oshape = inputs[0].shape_; Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, { + + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); + if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) { + common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward " + "with float16 inputs. 
" + "See https://mxnet.apache.org/api/faq/env_var " + "for more details."); + } + MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, { MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { Tensor < xpu, 1, IType > data = inputs[1].get_with_shape( Shape1(ishape.ProdShape(0, ishape.ndim())), s); Tensor grad_out = inputs[0].get_with_shape( - Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); + Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); Tensor grad_in = outputs[1].get(s); - if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) { if (req[embedding::kWeight] == kWriteTo) { grad_in = scalar(0.0f); } - AddTakeGrad(grad_in, data, grad_out); + if (safe_acc) { + // Temporary storage for safe accumulation + size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * sizeof(AType); + Tensor temp_space = + ctx.requested[embedding::kTempSpace].get_space_typed( + Shape1(temp_space_size), s); + Tensor temp_grad_in(reinterpret_cast(temp_space.dptr_), + grad_in.shape_, s); + AddTakeGrad(grad_in, temp_grad_in, data, grad_out); + } else { + AddTakeGrad(grad_in, data, grad_out); + } } else { LOG(FATAL) << "wrong req"; } @@ -696,7 +714,48 @@ struct TakeGradGeneralKernel { } }; -template +struct TakeGradGeneralKernelSafeAccumulation { + /*! + * \brief Map function for general case of take grad + * \param tid global thread id + * \param arr_grad ptr to in_grad + * \param temp ptr to temporal space to perform accumulation + * \param ograd ptr to out_grad + * \param src_indptr ptr to indptr to src indices + * \param original_idx ptr to original indices of the inputs + * \param in_strides strides of inputs + * \param out_strides strides of outputs + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template + MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, AType* temp, + const DType* ograd, + const IType* src_indptr, const IType* original_idx, + mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis, const int K) { + const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; + const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; + const int in_mid_index = in_rest_index / in_strides[axis]; + const int in_tail_index = (axis == in_ndims - 1) ? + 0 : (in_rest_index % in_strides[axis]); + temp[tid] = static_cast(arr_grad[tid]); + for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { + int out_mid_index = original_idx[i]; + out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index; + int target = in_tail_index + out_mid_index * in_strides[axis]; + target += (axis == 0) ? 
0 : in_head_index * out_strides[axis - 1]; + temp[tid] += ograd[target]; + } + arr_grad[tid] = temp[tid]; + } +}; + +template void TakeOpBackwardImpl(mshadow::Stream* s, const OpContext& ctx, const TBlob& arr, @@ -715,14 +774,23 @@ void TakeOpBackwardImpl(mshadow::Stream* s, size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); size_t original_idx_bytes = idxshape.Size() * sizeof(int); size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); - size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + size_t temp_accumulation_arrgrad_bytes = 0; + if (safe_acc) { + temp_accumulation_arrgrad_bytes = arr.Size() * sizeof(AType); + } + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + + temp_storage_bytes + temp_accumulation_arrgrad_bytes; Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); - int* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); - int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); - src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + AType* temp_accum_arrgrad_ptr = reinterpret_cast(workspace.dptr_); + int* sorted_idx_ptr = reinterpret_cast(workspace.dptr_ + temp_accumulation_arrgrad_bytes); + int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + temp_accumulation_arrgrad_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes + + temp_accumulation_arrgrad_bytes); Tensor temp_storage( - workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, Shape1(temp_storage_bytes), s); + workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes + temp_accumulation_arrgrad_bytes, + Shape1(temp_storage_bytes), s); // Reset indptr to zero Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr); // Fill original_idx @@ -759,16 +827,23 @@ void TakeOpBackwardImpl(mshadow::Stream* s, out_strides[i] = stride; } MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { - Kernel::Launch( - s, arrshape.Size(), arr.dptr(), ograd.dptr(), src_indptr_ptr, - original_idx_ptr, in_strides, out_strides, - arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast(arrshape[axis])); + if (safe_acc) { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), temp_accum_arrgrad_ptr, ograd.dptr(), + src_indptr_ptr, original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast(arrshape[axis])); + } else { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), + src_indptr_ptr, original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast(arrshape[axis])); + } }); }); } #ifdef __CUDACC__ -template +template void TakeOpBackwardImpl(mshadow::Stream* s, const OpContext& ctx, const TBlob& arr, @@ -808,13 +883,23 @@ void TakeOpBackwardImpl(mshadow::Stream* s, temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes); size_t original_idx_bytes = idxshape.Size() * sizeof(int); size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); - size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + size_t temp_accumulation_igrad_bytes = 0; + if (safe_acc) { + temp_accumulation_igrad_bytes = arr.Size() * sizeof(AType); + } + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + + temp_storage_bytes + temp_accumulation_igrad_bytes; Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); - sorted_idx_ptr = 
reinterpret_cast(workspace.dptr_); - int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); - src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); - temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; + AType* temp_accum_igrad_ptr = reinterpret_cast(workspace.dptr_); + sorted_idx_ptr = reinterpret_cast(workspace.dptr_ + temp_accumulation_igrad_bytes); + int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + temp_accumulation_igrad_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes + + temp_accumulation_igrad_bytes); + temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes + + temp_accumulation_igrad_bytes; + // Reset indptr to zero Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr); // Fill original_idx @@ -863,10 +948,19 @@ void TakeOpBackwardImpl(mshadow::Stream* s, out_strides[i] = stride; } MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { - Kernel::Launch( - s, arrshape.Size(), arr.dptr(), ograd.dptr(), - src_indptr_ptr, original_idx_ptr, in_strides, out_strides, - arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, static_cast(arrshape[axis])); + if (safe_acc) { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), temp_accum_igrad_ptr, ograd.dptr(), + src_indptr_ptr, original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, + static_cast(arrshape[axis])); + } else { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), + src_indptr_ptr, original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis, + static_cast(arrshape[axis])); + } }); }); } @@ -891,7 +985,14 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, // grad_in is the gradient of the inputs in the feed-forward Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); + if (!safe_acc && outputs[0].type_flag_ == mshadow::kFloat16) { + common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for TakeOpBackward " + "with float16 inputs. 
" + "See https://mxnet.apache.org/api/faq/env_var " + "for more details."); + } + MXNET_REAL_ACC_TYPE_SWITCH(outputs[0].type_flag_, DType, AType, { MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type // inputs are specified in the .cc file, which are the gradients from // the upper layer and the input index @@ -925,10 +1026,25 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, if (req[take_::kArr] == kWriteTo) { grad_in = scalar(0.0f); } - if (param.mode == take_::kClip) { - AddTakeGrad(grad_in, idx, grad_out); + if (safe_acc) { + // Temporary storage for safe accumulation + size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * sizeof(AType); + Tensor temp_space = + ctx.requested[take_::kTempSpace].get_space_typed( + Shape1(temp_space_size), s); + Tensor temp_grad_in(reinterpret_cast(temp_space.dptr_), + grad_in.shape_, s); + if (param.mode == take_::kClip) { + AddTakeGrad(grad_in, temp_grad_in, idx, grad_out); + } else { + AddTakeGrad(grad_in, temp_grad_in, idx, grad_out); + } } else { - AddTakeGrad(grad_in, idx, grad_out); + if (param.mode == take_::kClip) { + AddTakeGrad(grad_in, idx, grad_out); + } else { + AddTakeGrad(grad_in, idx, grad_out); + } } } else { LOG(FATAL) << "wrong req"; @@ -939,10 +1055,18 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, const TBlob& arr = outputs[0]; const TBlob& ograd = inputs[0]; - if (param.mode == take_::kClip) { - TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + if (safe_acc) { + if (param.mode == take_::kClip) { + TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + } else { + TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + } } else { - TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + if (param.mode == take_::kClip) { + TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + } else { + TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + } } } }); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 84cfe9cfa35d..519c02f141e9 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1522,19 +1522,23 @@ def test_lrn(): reason="Testing with naive engine consistently triggers illegal memory access. 
Tracked in #17713") def test_embedding_with_type(): def test_embedding_helper(data_types, weight_types, low_pad, high_pad): - NVD = [[20, 10, 20], [200, 10, 300]] - for N, V, D in NVD: - sym = mx.sym.Embedding(name='embedding', input_dim=V, output_dim=D) - ctx_list = [] - for data_type in data_types: - for weight_type in weight_types: - ctx_list.append({'ctx': mx.gpu(0), 'embedding_data': (N,), - 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}}) - ctx_list.append({'ctx': mx.cpu(0), 'embedding_data': (N,), - 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}}) - arg_params = {'embedding_data': np.random.randint(low=-low_pad, high=V+high_pad, size=(N,))} - check_consistency(sym, ctx_list, grad_req={'embedding_data': 'null','embedding_weight': 'write'}, - arg_params=arg_params, scale=0.1) + NVD = [[20, 10, 20], [200, 10, 300], [10000, 4, 20]] + for safe_accumulation in ['0', '1', None]: + for N, V, D in NVD: + with environment('MXNET_SAFE_ACCUMULATION', safe_accumulation): + if N > 1000 and safe_accumulation != '1': + break + sym = mx.sym.Embedding(name='embedding', input_dim=V, output_dim=D) + ctx_list = [] + for data_type in data_types: + for weight_type in weight_types: + ctx_list.append({'ctx': mx.gpu(0), 'embedding_data': (N,), + 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}}) + ctx_list.append({'ctx': mx.cpu(0), 'embedding_data': (N,), + 'type_dict': {'embedding_data': data_type, 'embedding_weight': weight_type}}) + arg_params = {'embedding_data': np.random.randint(low=-low_pad, high=V+high_pad, size=(N,))} + check_consistency(sym, ctx_list, grad_req={'embedding_data': 'null','embedding_weight': 'write'}, + arg_params=arg_params, scale=0.1) data_types = [np.float16, np.float32, np.float64, np.int32] weight_types = [np.float16, np.float32, np.float64] @@ -1547,47 +1551,91 @@ def test_embedding_helper(data_types, weight_types, low_pad, high_pad): @with_seed() def test_take_with_type(): sym = mx.sym.take(name='take') - for data_ndim in range(2, 5): - for idx_ndim in range(1, 4): - data_shape = () - for _ in range(data_ndim): - data_shape += (np.random.randint(low=3, high=6), ) - idx_shape = () - for _ in range(idx_ndim): - idx_shape += (np.random.randint(low=3, high=5), ) - ctx_list = [{'ctx': mx.gpu(0), 'take_indices': idx_shape, - 'take_a': data_shape, - 'type_dict': {'take_indices': np.float64, - 'take_a': np.float64}}, - {'ctx': mx.gpu(0), 'take_indices': idx_shape, - 'take_a': data_shape, - 'type_dict': {'take_indices': np.float32, - 'take_a': np.float32}}, - {'ctx': mx.gpu(0), 'take_indices': idx_shape, - 'take_a': data_shape, - 'type_dict': {'take_indices': np.float16, - 'take_a': np.float16}}, - {'ctx': mx.cpu(0), 'take_indices': idx_shape, - 'take_a': data_shape, - 'type_dict': {'take_indices': np.float64, - 'take_a': np.float64}}, - {'ctx': mx.cpu(0), 'take_indices': idx_shape, - 'take_a': data_shape, - 'type_dict': {'take_indices': np.float32, - 'take_a': np.float32}}, - {'ctx': mx.cpu(0), 'take_indices': idx_shape, - 'take_a': data_shape, - 'type_dict': {'take_indices': np.float16, - 'take_a': np.float16}}] - arg_params = {'take_indices': np.random.randint(low=0, - high=data_shape[0], - size=idx_shape), - 'take_a': np.random.normal(size=data_shape)} - check_consistency(sym, ctx_list, - grad_req={'take_indices': 'null', - 'take_a': 'write'}, - arg_params=arg_params) - + for safe_accumulation in ['0', '1', None]: + for data_ndim in range(2, 5): + for idx_ndim in range(1, 4): + data_shape = () + for 
_ in range(data_ndim): + data_shape += (np.random.randint(low=3, high=6), ) + idx_shape = () + for _ in range(idx_ndim): + idx_shape += (np.random.randint(low=3, high=5), ) + ctx_list = [{'ctx': mx.gpu(0), 'take_indices': idx_shape, + 'take_a': data_shape, + 'type_dict': {'take_indices': np.float64, + 'take_a': np.float64}}, + {'ctx': mx.gpu(0), 'take_indices': idx_shape, + 'take_a': data_shape, + 'type_dict': {'take_indices': np.float32, + 'take_a': np.float32}}, + {'ctx': mx.gpu(0), 'take_indices': idx_shape, + 'take_a': data_shape, + 'type_dict': {'take_indices': np.float16, + 'take_a': np.float16}}, + {'ctx': mx.cpu(0), 'take_indices': idx_shape, + 'take_a': data_shape, + 'type_dict': {'take_indices': np.float64, + 'take_a': np.float64}}, + {'ctx': mx.cpu(0), 'take_indices': idx_shape, + 'take_a': data_shape, + 'type_dict': {'take_indices': np.float32, + 'take_a': np.float32}}, + {'ctx': mx.cpu(0), 'take_indices': idx_shape, + 'take_a': data_shape, + 'type_dict': {'take_indices': np.float16, + 'take_a': np.float16}}] + arg_params = {'take_indices': np.random.randint(low=0, + high=data_shape[0], + size=idx_shape), + 'take_a': np.random.normal(size=data_shape)} + with environment('MXNET_SAFE_ACCUMULATION', safe_accumulation): + check_consistency(sym, ctx_list, + grad_req={'take_indices': 'null', + 'take_a': 'write'}, + arg_params=arg_params) + + # check a large num of indices: may underflow calculating gradient in FP16, + # if MXNET_SAFE_ACCUMULATION is not activated + with environment('MXNET_SAFE_ACCUMULATION', '1'): + data_size = 4 + indices_size = 10000 + out_dim = 20 + data_types = [np.float16, np.float32, np.float64] + indices_types = [np.float16, np.float32, np.float64, np.int32] + # axis 0 + sym = mx.sym.take(name='take', axis=0) + ctx_list = [] + for data_type in data_types: + for index_type in indices_types: + ctx_list.append({'ctx': mx.cpu(0), 'take_indices': (indices_size,), + 'take_a': (data_size, out_dim), + 'type_dict': {'take_indices': index_type, 'take_a': data_type}}) + ctx_list.append({'ctx': mx.gpu(0), 'take_indices': (indices_size,), + 'take_a': (data_size, out_dim), + 'type_dict': {'take_indices': index_type, 'take_a': data_type}}) + arg_params = {'take_indices': np.random.randint(0, data_size, + size=(indices_size,)), + 'take_a': np.random.normal(size=(data_size, out_dim))} + check_consistency(sym, ctx_list, + grad_req={'take_indices': 'null','take_a': 'write'}, + arg_params=arg_params) + # axis 1 + sym = mx.sym.take(name='take', axis=1) + ctx_list = [] + for data_type in data_types: + for index_type in indices_types: + ctx_list.append({'ctx': mx.cpu(0), 'take_indices': (indices_size,), + 'take_a': (data_size, out_dim), + 'type_dict': {'take_indices': index_type, 'take_a': data_type}}) + ctx_list.append({'ctx': mx.gpu(0), 'take_indices': (indices_size,), + 'take_a': (data_size, out_dim), + 'type_dict': {'take_indices': index_type, 'take_a': data_type}}) + arg_params = {'take_indices': np.random.randint(0, data_size, + size=(indices_size,)), + 'take_a': np.random.normal(size=(data_size, out_dim))} + check_consistency(sym, ctx_list, + grad_req={'take_indices': 'null','take_a': 'write'}, + arg_params=arg_params) @with_seed() @pytest.mark.serial
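# MXNET_SAFE_ACCUMULATION is the runtime toggle exercised in the tests above: with it
# set to '1', the Embedding/take backward pass accumulates gradients in a wider type
# (float32 for float16 data) inside a temporary buffer and casts back once at the end,
# so the large-index cases (many repeated indices summed into a single row) keep their
# precision; without it, float16 gradients may underflow, which is why those cases are
# only run with the variable enabled.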