From bfea5094ed46e0bcf4a45725e8b4c86d5404e6ad Mon Sep 17 00:00:00 2001
From: MoisesHer <50716238+MoisesHer@users.noreply.github.com>
Date: Sat, 5 Oct 2019 15:59:36 -0700
Subject: [PATCH] Embedding gradient performance optimization on GPU (#16355)

* Add Embedding backward Op for GPU

* Add some code documentation

* Use unnamed namespace for integer log2 function

* Fix lint issues

* Fix one more lint problem

* Remove unnecessary conditions ops

* Fix one more lint problem
---
 src/operator/tensor/indexing_op.cu | 233 +++++++++++++++++++++++++++++
 1 file changed, 233 insertions(+)

diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 77d85d8e1e10..0b4c20bf2bb5 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -545,6 +545,239 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+namespace {
+  /*
+   * \brief returns integer log2(a) rounded up
+   */
+  inline int ilog2(unsigned int a) {
+    int k = 1;
+    while (a >>= 1) k++;
+    return k;
+  }
+}
+
+/*
+ * \brief finds the lower and upper-bound positions of each unique element within
+ *        a sorted input array
+ * \param sorted_data input elements previously sorted
+ * \param bounds output containing all lower-bound followed by all upper-bound positions
+ * \param data_dim total number of elements in the input array
+ * \param vocab_dim maximum number of unique elements
+ */
+template<typename IType>
+__global__ void EmbeddingFindBounds(const IType *sorted_data,
+                                    IType *bounds,
+                                    const index_t data_dim,
+                                    const index_t vocab_dim) {
+  const index_t id = blockIdx.x * blockDim.x + threadIdx.x;
+  if (id >= vocab_dim) return;
+
+  // Binary search to find lower bound: stored at bounds[0..vocab_dim-1]
+  IType lower_bound = 0;
+  IType upper_bound = data_dim - 1;
+  IType mean;
+  while (lower_bound < upper_bound) {
+    mean = (lower_bound + upper_bound) / 2;
+    if (id <= sorted_data[mean])
+      upper_bound = mean;
+    else
+      lower_bound = mean + 1;
+  }
+  bool found_row = (sorted_data[lower_bound] == id);
+  if (!found_row) {
+    bounds[id] = -1;
+    bounds[vocab_dim + id] = -2;
+    return;
+  } else {
+    bounds[id] = lower_bound;
+  }
+
+  // Binary search to find upper bound: stored at bounds[vocab_dim..2*vocab_dim-1]
+  lower_bound = 0;
+  upper_bound = data_dim - 1;
+  while (lower_bound < upper_bound) {
+    mean = (lower_bound + upper_bound + 1) / 2;
+    if (id >= sorted_data[mean])
+      lower_bound = mean;
+    else
+      upper_bound = mean - 1;
+  }
+  bounds[vocab_dim + id] = upper_bound;
+}
+
+/*
+ * \brief kernel to compute gradient of EmbeddingOp
+ * \param grad_in input gradient data
+ * \param original_index reference to the position at original input data for each index
+ * \param index_bounds lower and upper-bounds positions of each unique index
+ * \param grad_out output gradient data
+ * \param embbedding_dim dimension of the dense embedding
+ * \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
+ * \param req write/add/null
+ */
+template<typename LType, typename DType, typename IType>
+__global__ void EmbeddingGradKernel(DType *grad_in,
+                                    const IType *original_index,
+                                    const IType *index_bounds,
+                                    const DType *grad_out,
+                                    const index_t embbedding_dim,
+                                    const index_t vocab_dim,
+                                    const int req) {
+  extern __shared__ int sharedmem[];
+  LType* grad_in_row = reinterpret_cast<LType *>(sharedmem);
+
+  // LType has to be bigger than DType, guarded in the launcher code
+  const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1;
+  const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
+  LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
+  const index_t aligned_emb_dim = embbedding_dim / n_val;
+  DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
+  LType Lvalue[1];
+  DType* Dvalues = reinterpret_cast<DType*>(Lvalue);
+
+  IType my_row = blockIdx.x;
+  if (my_row < vocab_dim) {
+    // Read lower and upper bounds for current row
+    IType lower_bound = index_bounds[my_row];
+    IType upper_bound = index_bounds[vocab_dim + my_row];
+    int nOccurrences = upper_bound - lower_bound + 1;
+
+    for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
+      // Initialize grad_in
+      if (req == kAddTo) {
+        grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
+      } else {
+        grad_in_row[threadIdx.x] = 0.0;
+      }
+      // Add all rows from grad_out according to indices in data
+      for (index_t data_idx=lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
+        *Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
+        for (index_t val_id = 0; val_id < n_val; val_id++) {
+          my_grad_in_row[val_id] += Dvalues[val_id];
+        }
+      }
+      // Save results
+      aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x];
+    }
+  }
+}
+
+template<typename DType, typename IType>
+void EmbeddingGradKernelCaller(const OpContext& ctx,
+                               mshadow::Tensor<gpu, 2, DType> grad_in,
+                               const mshadow::Tensor<gpu, 1, IType>& index,
+                               const mshadow::Tensor<gpu, 2, DType> &grad_out,
+                               const std::vector<OpReqType>& req) {
+  using namespace mxnet_op;
+  using namespace mshadow::expr;
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  const index_t data_dim = index.shape_[0];
+  const index_t vocab_dim = grad_in.shape_[0];
+  const index_t embbedding_dim = grad_in.shape_[1];
+
+  // Calculate amount of temporary storage
+  size_t sort_workspace_size = mxnet::op::SortByKeyWorkspaceSize<int, int, gpu>
+                                          (data_dim);
+  size_t workspace_size = 2 * data_dim * sizeof(int) +
+                          2 * vocab_dim * sizeof(int) + sort_workspace_size;
+
+  // Request temporary storage
+  Tensor<gpu, 1, char> workspace =
+            ctx.requested[embedding::kTempSpace].get_space_typed<gpu, 1, char>(
+              Shape1(workspace_size), s);
+
+  // Create tensors
+  size_t pos = 0;
+  Tensor<gpu, 1, int> sorted_data(reinterpret_cast<int*>(&workspace[pos]),
+                                  Shape1(data_dim), s);
+  pos += data_dim * sizeof(int);
+  // Reference to input data positions for each element of sorted_data
+  Tensor<gpu, 1, int> original_index(reinterpret_cast<int*>(&workspace[pos]),
+                                     Shape1(data_dim), s);
+  pos += data_dim * sizeof(int);
+  // lower and upper bound positions of each index within sorted_data
+  Tensor<gpu, 1, int> bounds_index(reinterpret_cast<int*>(&workspace[pos]),
+                                   Shape1(2 * vocab_dim), s);
+  pos += 2 * vocab_dim * sizeof(int);
+  Tensor<gpu, 1, char> Sort_temp_storage(&workspace[pos], Shape1(sort_workspace_size), s);
+
+  // Clip indices [0, vocab_dim-1]
+  Kernel<tcast_clip, gpu>::Launch(s, data_dim, sorted_data.dptr_, index.dptr_,
+                                  static_cast<int>(vocab_dim));
+
+  Kernel<range_fwd, gpu>::Launch(s, data_dim,
+                                 1, 0, 1, kWriteTo, original_index.dptr_);
+
+  // Sort indices array
+  int num_bits = ilog2((vocab_dim - 1));
+  mxnet::op::SortByKey(sorted_data, original_index, true, &Sort_temp_storage, 0, num_bits);
+
+  // Find lower & upper bounds of each possible index
+  const int threads_block_bounds = 128;
+  const int nblocks_bounds = (vocab_dim + threads_block_bounds - 1) / threads_block_bounds;
+  EmbeddingFindBounds<<<nblocks_bounds, threads_block_bounds, 0, Stream<gpu>::GetStream(s)>>>(
+      sorted_data.dptr_, bounds_index.dptr_, data_dim, vocab_dim);
+
+  // Compute Gradient
+  int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType));
+  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+    int nelems_per_thread = sizeof(LType) / sizeof(DType);
+    int threads_block_grad = 32;
+    int maxThreads = 1024;
+    while (threads_block_grad < (embbedding_dim/nelems_per_thread) &&
+          (threads_block_grad < maxThreads))
+      threads_block_grad += 32;
+    size_t required_shared = threads_block_grad * sizeof(LType);
+    dim3 blocks(vocab_dim, 1);
+    EmbeddingGradKernel<LType><<<blocks, threads_block_grad,
+                                 required_shared, Stream<gpu>::GetStream(s)>>>(
+        grad_in.dptr_, original_index.dptr_,
+        bounds_index.dptr_, grad_out.dptr_,
+        embbedding_dim, vocab_dim,
+        req[embedding::kWeight]);
+  });
+}
+
+template<>
+void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 2U);
+  CHECK_EQ(req[embedding::kData], kNullOp)
+          << "Embedding layer doesn't support calculate data gradient";
+  if (req[embedding::kWeight] == kNullOp) {
+    return;
+  }
+  CHECK_EQ(outputs[1].type_flag_, inputs[0].type_flag_);
+
+  const mxnet::TShape& ishape = inputs[1].shape_;
+  const mxnet::TShape& oshape = inputs[0].shape_;
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  CHECK_NE(req[embedding::kWeight], kWriteInplace)
+    << "Backward of Embedding does not support writing in place.";
+  MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
+      Tensor < gpu, 1, IType > data = inputs[1].get_with_shape<gpu, 1, IType>(
+        Shape1(ishape.ProdShape(0, ishape.ndim())), s);
+      Tensor<gpu, 2, DType> grad_out = inputs[0].get_with_shape<gpu, 2, DType>(
+        Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
+      Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);
+
+      if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
+        EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
+      } else {
+        LOG(FATAL) << "wrong req";
+      }
+    });
+  });
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute", EmbeddingOpForward<gpu>)
 .set_attr<FComputeEx>("FComputeEx", SparseEmbeddingOpForwardEx<gpu>);
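
Note (not part of the patch): the backward pass added above proceeds in three steps. It sorts the token indices while recording each element's original position, locates the lower and upper bound of every vocabulary row inside the sorted array (EmbeddingFindBounds), and then accumulates the matching grad_out rows into grad_in, one block per vocabulary row (EmbeddingGradKernel). The sketch below is a minimal CPU reference of those same steps, handy for checking expected outputs on tiny inputs; the function name cpu_embedding_grad and the rest of the harness are illustrative only and do not exist in MXNet.

// cpu_embedding_grad.cc -- illustrative CPU reference for the GPU path above.
// Names and layout are hypothetical; this is not part of the patch.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Accumulate grad_out rows into grad_in, one vocabulary row at a time,
// mirroring EmbeddingFindBounds + EmbeddingGradKernel.
void cpu_embedding_grad(const std::vector<int>& data,        // token indices
                        const std::vector<float>& grad_out,  // [data.size() x emb_dim]
                        std::vector<float>* grad_in,         // [vocab_dim x emb_dim]
                        int vocab_dim, int emb_dim) {
  const int data_dim = static_cast<int>(data.size());
  // Sort the indices while keeping the original position of each element
  // (the GPU code does this with SortByKey on (sorted_data, original_index)).
  std::vector<int> original_index(data_dim);
  std::iota(original_index.begin(), original_index.end(), 0);
  std::vector<int> sorted_data(data);
  std::sort(original_index.begin(), original_index.end(),
            [&](int a, int b) { return data[a] < data[b]; });
  std::sort(sorted_data.begin(), sorted_data.end());
  // For every vocabulary row, find the contiguous run of its occurrences in the
  // sorted array (EmbeddingFindBounds) and add the matching grad_out rows into
  // grad_in (EmbeddingGradKernel).
  for (int row = 0; row < vocab_dim; ++row) {
    auto lo = std::lower_bound(sorted_data.begin(), sorted_data.end(), row);
    auto hi = std::upper_bound(sorted_data.begin(), sorted_data.end(), row);
    for (auto it = lo; it != hi; ++it) {
      const int pos = original_index[it - sorted_data.begin()];
      for (int e = 0; e < emb_dim; ++e)
        (*grad_in)[row * emb_dim + e] += grad_out[pos * emb_dim + e];
    }
  }
}

int main() {
  // Tiny example: vocab of 4 rows, embedding dim 2, indices {2, 0, 2}.
  std::vector<int> data = {2, 0, 2};
  std::vector<float> grad_out = {1, 1,  2, 2,  3, 3};
  std::vector<float> grad_in(4 * 2, 0.f);
  cpu_embedding_grad(data, grad_out, &grad_in, 4, 2);
  for (int r = 0; r < 4; ++r)
    std::printf("row %d: %.0f %.0f\n", r, grad_in[r * 2], grad_in[r * 2 + 1]);
  // Expected: row 0 gets grad_out row 1, row 2 gets grad_out rows 0 and 2.
  return 0;
}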
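
Note (not part of the patch): EmbeddingGradKernel moves data through a wider load type LType, chosen at runtime by get_load_type and MXNET_LOAD_TYPE_SWITCH, and reinterprets each wide word back into DType lanes for the per-element additions, so one global-memory transaction carries several embedding values. The standalone CUDA sketch below fixes LType to double and DType to float to show the same reinterpret_cast pattern on a toy row reduction; the kernel and file names are hypothetical.

// wide_load_sketch.cu -- illustrative only, not part of the patch. Two floats are
// moved with a single 64-bit load/store (LType = double), while the additions are
// still performed element-wise on the underlying floats.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void add_rows_vectorized(const float* grad_out, float* grad_in,
                                    int n_rows, int row_len) {
  using LType = double;                                     // wide load type
  const int n_val = sizeof(LType) / sizeof(float);          // 2 floats per load
  const LType* aligned_grad_out = reinterpret_cast<const LType*>(grad_out);
  LType* aligned_grad_in = reinterpret_cast<LType*>(grad_in);
  const int aligned_len = row_len / n_val;                  // row_len % 2 == 0 assumed
  for (int j = threadIdx.x; j < aligned_len; j += blockDim.x) {
    LType acc = aligned_grad_in[j];                         // kAddTo-style initialization
    float* acc_vals = reinterpret_cast<float*>(&acc);
    for (int r = 0; r < n_rows; ++r) {
      LType v = aligned_grad_out[r * aligned_len + j];      // one wide load
      const float* vals = reinterpret_cast<const float*>(&v);
      for (int k = 0; k < n_val; ++k) acc_vals[k] += vals[k];
    }
    aligned_grad_in[j] = acc;                               // one wide store
  }
}

int main() {
  const int n_rows = 3, row_len = 8;
  float h_out[n_rows * row_len], h_in[row_len];
  for (int i = 0; i < n_rows * row_len; ++i) h_out[i] = 1.f;
  for (int i = 0; i < row_len; ++i) h_in[i] = 0.f;
  float *d_out, *d_in;
  cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(h_out));
  cudaMalloc(reinterpret_cast<void**>(&d_in), sizeof(h_in));
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  add_rows_vectorized<<<1, 32>>>(d_out, d_in, n_rows, row_len);
  cudaMemcpy(h_in, d_in, sizeof(h_in), cudaMemcpyDeviceToHost);
  for (int i = 0; i < row_len; ++i) std::printf("%.0f ", h_in[i]);  // expect eight 3s
  std::printf("\n");
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}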