
Embedding gradient performance optimization on GPU #16355

Merged: 7 commits, Oct 5, 2019. Showing changes from 2 commits.
239 changes: 239 additions & 0 deletions src/operator/tensor/indexing_op.cu
@@ -545,6 +545,245 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
});
}

/*
 * \brief returns the number of bits needed to represent a
 *        (i.e. floor(log2(a)) + 1 for a > 0)
 */
inline int ilog2(unsigned int a) {
int k = 1;
while (a >>= 1) k++;
return k;
}
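// Worked example for clarity: ilog2(5) == 3 (5 = 0b101 needs 3 bits) and
// ilog2(4) == 3; below, ilog2(vocab_dim - 1) bounds the radix-sort bit count.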

/*
* \brief finds the lower and upper-bound positions of each unique element within a sorted input array
* \param sorted_data input elements previously sorted
* \param bounds output containing all lower-bound followed by all upper-bound positions
* \param data_dim total number of elements in the input array
* \param vocab_dim maximum number of unique elements
*/
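// Illustrative example (assumed input): for sorted_data = [0, 0, 2, 2, 2, 3],
// data_dim = 6 and vocab_dim = 4, the kernel writes
//   bounds = [0, -1, 2, 5, 1, -1, 4, 5]
// lower bounds in bounds[0..3], upper bounds in bounds[4..7], and -1 where an
// index never occurs (here index 1).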
template <typename IType>
__global__ void EmbeddingFindBounds(const IType *sorted_data,
IType *bounds,
const index_t data_dim,
const index_t vocab_dim) {
const index_t id = blockIdx.x * blockDim.x + threadIdx.x;

// Binary search to find lower bound: stored at bounds[0..vocab_dim-1]
IType lower_bound = 0;
IType upper_bound = data_dim - 1;
IType mean;
while (lower_bound < upper_bound) {
mean = (lower_bound + upper_bound) / 2;
if (id <= sorted_data[mean])
upper_bound = mean;
else
lower_bound = mean + 1;
}
bool found_row = (sorted_data[lower_bound] == id);

if (id < vocab_dim) {
bounds[id] = (found_row) ? lower_bound : -1;
}

// Binary search to find upper bound: stored at bounds[vocab_dim..2*vocab_dim-1]
lower_bound = 0;
upper_bound = data_dim - 1;
while (lower_bound < upper_bound) {
mean = (lower_bound + upper_bound + 1) / 2;
if (id >= sorted_data[mean])
lower_bound = mean;
else
upper_bound = mean - 1;
}
found_row = (sorted_data[upper_bound] == id);

if (id < vocab_dim) {
bounds[vocab_dim + id] = (found_row) ? upper_bound : -1;
}
}

/*
* \brief kernel to compute gradient of EmbeddingOp
* \param grad_in input gradient data
* \param original_index reference to the position at original input data for each index
* \param index_bounds lower and upper-bounds positions of each unique index
* \param grad_out output gradient data
* \param embedding_dim dimension of the dense embedding
* \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
* \param rows_per_block number of grad_in rows to be computed by each block
* \param req write/add/null
*/
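// Sketch of one block's work (assumed example): if index 2 occurs at sorted
// positions lower_bound = 1 and upper_bound = 2, with original_index[1] = 0 and
// original_index[2] = 2, then the block owning row 2 sums grad_out rows 0 and 2
// into grad_in row 2, each thread covering sizeof(LType)/sizeof(DType) values
// per vectorized load.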
template <typename LType, typename DType, typename IType>
__global__ void EmbeddingGradKernel(DType *grad_in,
const IType *original_index,
const IType *index_bounds,
const DType *grad_out,
const index_t embedding_dim,
const index_t vocab_dim,
const int rows_per_block,
const int req) {

extern __shared__ int sharedmem[];
LType* grad_in_row = reinterpret_cast<LType*>(sharedmem);

// LType has to be at least as large as DType, guarded in the launcher code
const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1;
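// e.g. (assumed sizes) a 2-byte DType paired with an 8-byte LType gives
// n_val = 4 embedding values handled per vectorized load/store.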
const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
const index_t aligned_emb_dim = embedding_dim / n_val;
DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
LType Lvalue[1];
DType* Dvalues = reinterpret_cast<DType*>(Lvalue);

for (index_t row = 0; row < rows_per_block; ++row) {
IType my_row = blockIdx.x * rows_per_block + row;
if (my_row < vocab_dim) {

// Read lower and upper bounds for current row
IType lower_bound = index_bounds[my_row];
IType upper_bound = index_bounds[vocab_dim + my_row];
int nOccurrences = (lower_bound != -1) ? (upper_bound - lower_bound + 1) : 0;

for (index_t emb_id = threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
// Initialize grad_in
if (req == kAddTo) {
grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
} else {
grad_in_row[threadIdx.x] = 0.0;
}
// Add all rows from grad_out according to indices in data
if (nOccurrences) {
for (index_t data_idx = lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
*Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
for (index_t val_id = 0; val_id < n_val; val_id++) {
my_grad_in_row[val_id] += Dvalues[val_id];
}
}
}

// Save results
aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x];
}
}
}
}

template<typename gpu, typename IType, typename DType>
void EmbeddingGradKernelCaller(const OpContext& ctx,
mshadow::Tensor<gpu, 2, DType> grad_in,
const mshadow::Tensor<gpu, 1, IType>& index,
const mshadow::Tensor<gpu, 2, DType> &grad_out,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow::expr;

Stream<gpu> *s = ctx.get_stream<gpu>();
const index_t data_dim = index.shape_[0];
const index_t vocab_dim = grad_in.shape_[0];
const index_t embedding_dim = grad_in.shape_[1];

// Calculate amount of temporary storage
size_t sort_workspace_size = mxnet::op::SortByKeyWorkspaceSize<int, int, gpu>
(data_dim);
size_t workspace_size = 2 * data_dim * sizeof(int) +
2 * vocab_dim * sizeof(int) + sort_workspace_size;
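// Workspace layout, in order: sorted indices (data_dim ints), original positions
// (data_dim ints), lower/upper bounds (2 * vocab_dim ints), SortByKey temp storage.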

// Request temporary storage
Tensor<gpu, 1, char> workspace =
ctx.requested[embedding::kTempSpace].get_space_typed<gpu, 1, char>(
Shape1(workspace_size), s);

// Create tensors
size_t pos = 0;
Tensor<gpu, 1, int> sorted_data(reinterpret_cast<int*>(&workspace[pos]),
Shape1(data_dim), s);
pos += data_dim * sizeof(int);
// Reference to input data positions for each element of sorted_data
Tensor<gpu, 1, int> original_index(reinterpret_cast<int*>(&workspace[pos]),
Shape1(data_dim), s);
pos += data_dim * sizeof(int);
// lower and upper bound positions of each index within sorted_data
Tensor<gpu, 1, int> bounds_index(reinterpret_cast<int*>(&workspace[pos]),
Shape1(2 * vocab_dim), s);
pos += 2 * vocab_dim * sizeof(int);
Tensor<gpu, 1, char> Sort_temp_storage(&workspace[pos], Shape1(sort_workspace_size), s);

// Clip indices [0, vocab_dim-1]
Kernel<tcast_clip, gpu>::Launch(s, data_dim, sorted_data.dptr_, index.dptr_,
static_cast<int>(vocab_dim));

Kernel<range_fwd, gpu>::Launch(s, data_dim,
1, 0, 1, kWriteTo, original_index.dptr_);

// Sort indices array
int num_bits = ilog2(vocab_dim - 1);
mxnet::op::SortByKey(sorted_data, original_index, true, &Sort_temp_storage, 0, num_bits);
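// After the sort, sorted_data holds the clipped indices in ascending order and
// original_index[i] gives the position each sorted index came from, e.g.
// (assumed input) data = [2, 0, 2] -> sorted_data = [0, 2, 2],
// original_index = [1, 0, 2].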

// Find lower & upper bounds of each possible index
const int threads_block_bounds = 128;
const int nblocks_bounds = (vocab_dim + threads_block_bounds - 1) / threads_block_bounds;
EmbeddingFindBounds<<<nblocks_bounds, threads_block_bounds, 0, Stream<gpu>::GetStream(s)>>>(
sorted_data.dptr_, bounds_index.dptr_, data_dim, vocab_dim);

// Compute Gradient
int ltype = mxnet::common::cuda::get_load_type(embedding_dim * sizeof(DType));
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
const int rows_per_block = 1;
int nelems_per_thread = sizeof(LType) / sizeof(DType);
int threads_block_grad = 32;
int maxThreads = 1024;
while (threads_block_grad < (embedding_dim / nelems_per_thread) && (threads_block_grad < maxThreads))
threads_block_grad += 32;
size_t required_shared = threads_block_grad * sizeof(LType);
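// Block size grows in warp-sized (32-thread) steps until one block covers a full
// embedding row (in LType elements) or reaches 1024 threads; shared memory holds
// one LType per thread for the row being accumulated.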
dim3 blocks((vocab_dim + rows_per_block - 1) / rows_per_block, 1);
EmbeddingGradKernel<LType><<<blocks, threads_block_grad, required_shared, Stream<gpu>::GetStream(s)>>>(
grad_in.dptr_, original_index.dptr_,
bounds_index.dptr_, grad_out.dptr_,
embedding_dim, vocab_dim,
rows_per_block, req[embedding::kWeight]);
});
}

template<>
void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 2U);
CHECK_EQ(req[embedding::kData], kNullOp)
<< "Embedding layer doesn't support calculate data gradient";
if (req[embedding::kWeight] == kNullOp) {
return;
}
CHECK_EQ(outputs[1].type_flag_, inputs[0].type_flag_);

const mxnet::TShape& ishape = inputs[1].shape_;
const mxnet::TShape& oshape = inputs[0].shape_;

Stream<gpu> *s = ctx.get_stream<gpu>();
CHECK_NE(req[embedding::kWeight], kWriteInplace)
<< "Backward of Embedding does not support writing in place.";
MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
Tensor<gpu, 1, IType> data = inputs[1].get_with_shape<gpu, 1, IType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Tensor<gpu, 2, DType> grad_out = inputs[0].get_with_shape<gpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);

if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
} else {
LOG(FATAL) << "wrong req";
}
});
});
}

NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpForwardEx<gpu>);