This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Safe accumulation for computing gradient in Embedding & Take #18385

Merged
merged 6 commits on Aug 14, 2020
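This page captures no PR description, but the change itself is straightforward: when the embedding/take gradient is stored in float16, repeatedly applying += in float16 discards low-order bits, so the new overloads accumulate into a temporary buffer of a wider type (float32 for float16 data) and cast back to the storage type once at the end. A minimal standalone sketch of that idea, using float widened to double purely for illustration (standard C++ has no half type; the PR itself widens float16 to float32):

// Minimal sketch of "safe accumulation": keep the running sum in a wider
// type and narrow once at the end, instead of accumulating in the storage
// type. float/double stand in for float16/float32 here.
#include <cstdio>
#include <vector>

float naive_sum(const std::vector<float>& grads) {
  float acc = 0.0f;                  // accumulate in the storage type
  for (float g : grads) acc += g;    // rounding error grows with every add
  return acc;
}

float safe_sum(const std::vector<float>& grads) {
  double acc = 0.0;                  // accumulate in a wider type
  for (float g : grads) acc += static_cast<double>(g);
  return static_cast<float>(acc);    // single narrowing cast at the end
}

int main() {
  std::vector<float> grads(1 << 24, 1e-4f);  // many small gradient contributions
  std::printf("naive: %f\nsafe:  %f\n", naive_sum(grads), safe_sum(grads));
  return 0;
}

The two results diverge noticeably once the running sum grows large relative to the per-step contribution, which is exactly the situation an embedding gradient with a large vocabulary and many repeated indices runs into.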
77 changes: 77 additions & 0 deletions 3rdparty/mshadow/mshadow/cuda/tensor_gpu-inl.cuh
@@ -641,6 +641,43 @@ __global__ void AddTakeGradKernel(DstPlan dst,
}
}

template<bool clip, int x_bits, typename DstPlan, typename ATypePlan,
typename SrcPlan1, typename SrcPlan2>
__global__ void AddTakeGradKernel(DstPlan dst,
ATypePlan temp,
SrcPlan1 index, SrcPlan2 src,
index_t ymax, index_t xmax, const int K) {
const unsigned x_size = 1 << x_bits;
const int xindex = blockIdx.x * x_size + threadIdx.x;
__shared__ int ptr;
if (xindex < xmax) {
for (unsigned y = 0; y < K; ++y) {
temp.REval(y, xindex) = dst.Eval(y, xindex);
}
}
for (unsigned y = 0; y < ymax; ++y) {
if (threadIdx.x == 0) {
ptr = index.Eval(0, y);
if (clip) {
if (ptr <= 0) ptr = 0;
else if (ptr >= K) ptr = K - 1;
} else {
ptr %= K;
if (ptr < 0) ptr += K;
}
}
__syncthreads();
if (xindex < xmax) {
temp.REval(ptr, xindex) += src.Eval(y, xindex);
}
}
if (xindex < xmax) {
for (unsigned y = 0; y < K; ++y) {
dst.REval(y, xindex) = temp.Eval(y, xindex);
}
}
}

template<int warp_bits, int SZ, typename DType, typename IdxType>
__global__ void AddTakeGradLargeBatchKernel(DType* dst,
const IdxType *sorted, const IdxType *index,
@@ -733,6 +770,46 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}

template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
Tensor<gpu, 2, AType> temp,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
CHECK_EQ(src.CheckContiguous(), true);
const int kUnitBits = kMemUnitBits + 1;
dim3 dimBlock(1 << kUnitBits);
dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits);

CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch";
CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch";
CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad");
cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
const int K = dst.shape_[0];

if (clip) {
AddTakeGradKernel<true, kUnitBits>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(temp),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
} else {
AddTakeGradKernel<false, kUnitBits>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(temp),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1), K);
}
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel);
}

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
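For readers less familiar with mshadow's Plan machinery, the staging pattern in the new AddTakeGradKernel above can be sketched as a standalone CUDA kernel: copy dst into a wider float temp buffer, accumulate every src row into the row its index selects (one thread per embedding column, so no atomics are needed), and write back with a single narrowing cast. The __half element type, the clip-only index handling, and the flat one-kernel layout are simplifying assumptions, not the PR's actual configuration:

// Simplified sketch of the safe-accumulation pattern used by AddTakeGradKernel.
#include <cuda_fp16.h>

__global__ void add_take_grad_safe(__half* dst, float* temp,
                                   const int* index, const __half* src,
                                   int n_rows_src, int n_rows_dst, int dim) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= dim) return;
  // Stage dst into the wider accumulation buffer.
  for (int row = 0; row < n_rows_dst; ++row)
    temp[row * dim + col] = __half2float(dst[row * dim + col]);
  // Accumulate every src row into the dst row its index selects.
  for (int y = 0; y < n_rows_src; ++y) {
    int row = index[y];
    if (row < 0) row = 0;                        // "clip" mode only, for brevity
    if (row >= n_rows_dst) row = n_rows_dst - 1;
    temp[row * dim + col] += __half2float(src[y * dim + col]);
  }
  // Narrow back to the storage type once, after all additions.
  for (int row = 0; row < n_rows_dst; ++row)
    dst[row * dim + col] = __float2half(temp[row * dim + col]);
}

A launch along the lines of add_take_grad_safe<<<(dim + 255) / 256, 256>>>(dst, temp, index, src, n_src, n_dst, dim) covers each embedding column with one thread; the real launcher instead uses (1 << kUnitBits)-thread blocks and the kernel broadcasts each index through shared memory.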
26 changes: 26 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor.h
@@ -848,6 +848,19 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
* \param index index to take
* \param src source output
*/
template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
Tensor<cpu, 2, AType> temp,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
* \brief CPU/GPU: Gradient accumulation for the embedding matrix, with safe accumulation.
dst[index[i]] += src[i]
* \param dst destination
* \param temp temporary storage for safe accumulation
* \param index index to take
* \param src source output
*/
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
@@ -861,6 +874,19 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
* \param index original index of the sorted indices
* \param src source output
*/
template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
Tensor<gpu, 2, AType> temp,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
/*!
* \brief CPU/GPU: Gradient accumulation for the embedding matrix, with safe accumulation.
dst[index[i]] += src[i]
* \param dst destination
* \param temp temporary storage for safe accumulation
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
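The contract of the new overloads is the same as the existing AddTakeGrad — dst[index[i]] += src[i] — except that the caller also supplies temp, a tensor with dst's shape whose element type is the (wider) accumulation type. A small worked example of those semantics, with plain arrays standing in for mshadow tensors and float/double standing in for the float16/float32 pairing MXNet uses:

// Reference semantics of AddTakeGrad(dst, temp, index, src) with safe accumulation.
#include <cstdio>

int main() {
  const int K = 3, C = 2, N = 4;                 // vocab rows, embedding dim, #indices
  float dst[K][C] = {{0, 0}, {0, 0}, {0, 0}};    // storage type (stands in for float16)
  double temp[K][C];                             // wider accumulation buffer
  const int index[N] = {1, 0, 1, 2};
  const float src[N][C] = {{1, 1}, {2, 2}, {3, 3}, {4, 4}};

  for (int j = 0; j < K; ++j)                    // stage dst into temp
    for (int i = 0; i < C; ++i) temp[j][i] = dst[j][i];
  for (int y = 0; y < N; ++y)                    // dst[index[y]] += src[y], done in temp
    for (int i = 0; i < C; ++i) temp[index[y]][i] += src[y][i];
  for (int j = 0; j < K; ++j)                    // narrow back to the storage type
    for (int i = 0; i < C; ++i) dst[j][i] = static_cast<float>(temp[j][i]);

  for (int j = 0; j < K; ++j)
    std::printf("dst[%d] = {%g, %g}\n", j, dst[j][0], dst[j][1]);
  // Prints dst[0] = {2, 2}, dst[1] = {4, 4} (src rows 0 and 2 both map to dst row 1),
  // and dst[2] = {4, 4}.
  return 0;
}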
33 changes: 33 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor_cpu-inl.h
@@ -539,6 +539,39 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
}
}

// safe accumulation
template<bool clip, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
Tensor<cpu, 2, AType> temp,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
const index_t K = dst.shape_[0];
const index_t C = dst.shape_[1];
for (index_t j = 0; j < K; ++j) {
for (index_t i = 0; i < C; ++i) {
temp[j][i] = dst[j][i];
}
}
for (index_t y = 0; y < index.size(0); ++y) {
index_t j = index[y];
if (clip) {
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
} else {
j %= K;
if (j < 0) j += K;
}
for (index_t i = 0; i < C; ++i) {
temp[j][i] += src[y][i];
}
}
for (index_t j = 0; j < K; ++j) {
for (index_t i = 0; i < C; ++i) {
dst[j][i] = temp[j][i];
}
}
}

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
8 changes: 8 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor_gpu-inl.h
@@ -239,6 +239,14 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}

template<bool clip, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
Tensor<gpu, 2, AType> temp,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGrad<clip, IndexType, DType>(dst, temp, index, src);
}

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
84 changes: 52 additions & 32 deletions src/operator/tensor/indexing_op.cu
@@ -684,27 +684,25 @@ __global__ void EmbeddingFindBounds(const IType *sorted_data,
* \param grad_out output gradient data
* \param embbedding_dim dimension of the dense embedding
* \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
* \param nelems_per_load number of elements per load, based on sizeof(LType) / sizeof(DType)
* \param req write/add/null
*/
template <typename LType, typename DType, typename IType>
template <typename AType, typename LType, typename DType, typename IType>
__global__ void EmbeddingGradKernel(DType *grad_in,
const IType *original_index,
const IType *index_bounds,
const DType *grad_out,
const index_t embbedding_dim,
const index_t vocab_dim,
const int req) {
const IType *original_index,
const IType *index_bounds,
const DType *grad_out,
const index_t embbedding_dim,
const index_t vocab_dim,
const int nelems_per_load,
const int req) {
extern __shared__ int sharedmem[];
LType* grad_in_row = reinterpret_cast<LType *>(sharedmem);

// LType has to be bigger than DType, guarded in the launcher code
const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1;
AType* grad_in_row = reinterpret_cast<AType *>(sharedmem);
const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
const index_t aligned_emb_dim = embbedding_dim / n_val;
DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
LType Lvalue[1];
DType* Dvalues = reinterpret_cast<DType*>(Lvalue);
const index_t aligned_emb_dim = embbedding_dim / nelems_per_load;
LType load_value[1];
DType* data_values = reinterpret_cast<DType*>(load_value);

IType my_row = blockIdx.x;
if (my_row < vocab_dim) {
@@ -716,29 +714,37 @@ __global__ void EmbeddingGradKernel(DType *grad_in,
for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
// Initialize grad_in
if (req == kAddTo) {
grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
*load_value = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast<AType>(data_values[val_id]);
}
} else {
grad_in_row[threadIdx.x] = 0.0;
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
grad_in_row[val_id * blockDim.x + threadIdx.x] = static_cast<AType>(0.0);
}
}
// Add all rows from grad_out according to indices in data
for (index_t data_idx=lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
*Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
for (index_t val_id = 0; val_id < n_val; val_id++) {
my_grad_in_row[val_id] += Dvalues[val_id];
*load_value = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
grad_in_row[val_id * blockDim.x + threadIdx.x] += static_cast<AType>(data_values[val_id]);
}
}
// Save results
aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x];
for (index_t val_id = 0; val_id < nelems_per_load; val_id++) {
data_values[val_id] = static_cast<DType>(grad_in_row[val_id * blockDim.x + threadIdx.x]);
}
aligned_grad_in[my_row * aligned_emb_dim + emb_id] = *load_value;
}
}
}

template<typename gpu, typename IType, typename DType>
template<typename AType, typename IType, typename DType>
void EmbeddingGradKernelCaller(const OpContext& ctx,
mshadow::Tensor<gpu, 2, DType> grad_in,
const mshadow::Tensor<gpu, 1, IType>& index,
const mshadow::Tensor<gpu, 2, DType> &grad_out,
const std::vector<OpReqType>& req) {
mshadow::Tensor<gpu, 2, DType> grad_in,
const mshadow::Tensor<gpu, 1, IType>& index,
const mshadow::Tensor<gpu, 2, DType> &grad_out,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow::expr;

@@ -792,20 +798,23 @@ void EmbeddingGradKernelCaller(const OpContext& ctx,

// Compute Gradient
int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType));

MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
int nelems_per_thread = sizeof(LType) / sizeof(DType);
CHECK_LE(sizeof(DType), sizeof(LType));
int nelems_per_load = sizeof(LType) / sizeof(DType);
int threads_block_grad = 32;
int maxThreads = 1024;
while (threads_block_grad < (embbedding_dim/nelems_per_thread) &&
while (threads_block_grad < (embbedding_dim/nelems_per_load) &&
(threads_block_grad < maxThreads))
threads_block_grad += 32;
size_t required_shared = threads_block_grad * sizeof(LType);
size_t required_shared = threads_block_grad * nelems_per_load * sizeof(AType);
dim3 blocks(vocab_dim, 1);
EmbeddingGradKernel<LType><<<blocks, threads_block_grad, required_shared,
EmbeddingGradKernel<AType, LType><<<blocks, threads_block_grad, required_shared,
Stream<gpu>::GetStream(s)>>>(
grad_in.dptr_, original_index.dptr_,
bounds_index.dptr_, grad_out.dptr_,
embbedding_dim, vocab_dim,
nelems_per_load,
req[embedding::kWeight]);
});
}
@@ -831,9 +840,17 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& oshape = inputs[0].shape_;

Stream<gpu> *s = ctx.get_stream<gpu>();

CHECK_NE(req[embedding::kWeight], kWriteInplace)
<< "Backward of Embedding does not support writing in place.";
MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward "
"with float16 inputs. "
"See https://mxnet.apache.org/api/faq/env_var "
"for more details.");
}
MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, {
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
Tensor < gpu, 1, IType > data = inputs[1].get_with_shape<gpu, 1, IType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Expand All @@ -842,7 +859,10 @@ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);

if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
if (safe_acc)
EmbeddingGradKernelCaller<AType>(ctx, grad_in, data, grad_out, req);
else
EmbeddingGradKernelCaller<DType>(ctx, grad_in, data, grad_out, req);
} else {
LOG(FATAL) << "wrong req";
}
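On the MXNet side the new path is opt-in: EmbeddingOpBackward reads MXNET_SAFE_ACCUMULATION, and when it is set it instantiates the gradient kernel with the wider accumulation type chosen by MXNET_REAL_ACC_TYPE_SWITCH (float32 when the gradient dtype is float16); otherwise it accumulates in the storage type and logs a one-time recommendation for float16 users. A condensed sketch of that dispatch with the MXNet macros and dmlc::GetEnv replaced by plain C++ (the function names below are illustrative, not MXNet API):

// Hypothetical, condensed version of the runtime dispatch added above.
#include <cstdlib>
#include <cstring>

template <typename AType, typename DType>
void embedding_grad_kernel(DType* grad_in, const DType* grad_out, int n) {
  for (int i = 0; i < n; ++i) {
    AType acc = static_cast<AType>(grad_in[i]);   // accumulate in the wider type
    acc += static_cast<AType>(grad_out[i]);
    grad_in[i] = static_cast<DType>(acc);         // narrow back to the storage type
  }
}

template <typename DType>
void embedding_backward(DType* grad_in, const DType* grad_out, int n, bool is_fp16) {
  const char* env = std::getenv("MXNET_SAFE_ACCUMULATION");
  const bool safe_acc = env != nullptr && std::strcmp(env, "1") == 0;
  if (safe_acc && is_fp16)
    embedding_grad_kernel<float, DType>(grad_in, grad_out, n);   // widened accumulation
  else
    embedding_grad_kernel<DType, DType>(grad_in, grad_out, n);   // storage-type accumulation
}

At runtime a user enables the path by exporting MXNET_SAFE_ACCUMULATION=1 before training, as the added warning message suggests.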