Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
making AddTakeGrad as default for backward of embedding and take to a…
Browse files Browse the repository at this point in the history
…void nan (#11795)
  • Loading branch information
haojin2 authored and eric-haibin-lin committed Jul 25, 2018
1 parent 0b8b939 commit fe1c7ab
Showing 1 changed file with 2 additions and 74 deletions.
76 changes: 2 additions & 74 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -548,46 +548,6 @@ struct tcast_clip {
}
};

template<typename xpu, typename IndexType, typename DType>
void AddTakeGradLargeBatchCaller(const OpContext& ctx, mshadow::Tensor<xpu, 2, DType> dst,
const mshadow::Tensor<xpu, 1, IndexType>& index,
const mshadow::Tensor<xpu, 2, DType> &src) {
using namespace mxnet_op;
using namespace mshadow::expr;

Stream<xpu> *s = ctx.get_stream<xpu>();

// Calculate amount of temporary storage
size_t sort_workspace_size = mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>
(index.shape_.Size());
size_t addtake_workspace_size = mxnet::op::AddTakeGradLargeBatchWorkspaceSize<int, xpu>
(index.shape_.Size());
size_t temp_storage_size = std::max(sort_workspace_size, addtake_workspace_size);
size_t workspace_size = 2*(index.shape_.Size()*sizeof(int)) + temp_storage_size;

// Request temporary storage
Tensor<xpu, 1, char> workspace =
ctx.requested[embedding::kTempSpace].get_space_typed<xpu, 1, char>(
Shape1(workspace_size), s);

// Create tensors
size_t pos = 0;
Tensor<xpu, 1, int> sorted_data(reinterpret_cast<int*>(&workspace[pos]),
Shape1(index.shape_.Size()), s);
pos += index.shape_.Size()*sizeof(int);
Tensor<xpu, 1, int> original_index(reinterpret_cast<int*>(&workspace[pos]),
Shape1(index.shape_.Size()), s);
pos += index.shape_.Size()*sizeof(int);
Tensor<xpu, 1, char> temp_storage(&workspace[pos], Shape1(temp_storage_size), s);
Kernel<tcast_clip, xpu>::Launch(s, index.shape_.Size(), sorted_data.dptr_, index.dptr_,
static_cast<int>(dst.shape_[0]));
Kernel<range_fwd, xpu>::Launch(s, index.shape_.Size(),
1, 0, 1, kWriteTo, original_index.dptr_);
int num_bits = ilog2((dst.shape_[0] - 1));
mxnet::op::SortByKey(sorted_data, original_index, true, &temp_storage, 0, num_bits);
mxnet::op::AddTakeGradLargeBatch(dst, sorted_data, original_index, src, &temp_storage);
}

template<typename xpu>
void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -619,25 +579,7 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
if (req[embedding::kWeight] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
// shape_out_prod ~= the number of elements loaded in AddTakeGrad
// shape_in_prod ~= the number of elements stored in AddTakeGrad
// When the number of elements processed is low, use AddTakeGrad.
// The approximate cut-off value 16384 was found experimentally on Titan X Pascal
uint64_t shape_in_prod =
static_cast<uint64_t>(grad_in.shape_[0])*
static_cast<uint64_t>(grad_in.shape_[1]);
uint64_t shape_out_prod =
static_cast<uint64_t>(grad_out.shape_[0])*
static_cast<uint64_t>(grad_out.shape_[1]);

static bool force_addtakegrad =
dmlc::GetEnv("MXNET_FORCE_ADDTAKEGRAD", false);
if (force_addtakegrad || (shape_out_prod < (uint64_t)16384 &&
shape_in_prod < (uint64_t)16384)) {
AddTakeGrad(grad_in, data, grad_out);
} else {
AddTakeGradLargeBatchCaller(ctx, grad_in, data, grad_out);
}
AddTakeGrad(grad_in, data, grad_out);
} else {
LOG(FATAL) << "wrong req";
}
Expand Down Expand Up @@ -1132,21 +1074,7 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
// shape_out_prod ~= the number of elements loaded in AddTakeGrad
// shape_in_prod ~= the number of elements stored in AddTakeGrad
// When the number of elements processed is low, use AddTakeGrad.
// The approximate cut-off value 16384 was found experimentally on Titan X Pascal
uint64_t shape_in_prod =
static_cast<uint64_t>(grad_in.shape_[0])*
static_cast<uint64_t>(grad_in.shape_[1]);
uint64_t shape_out_prod =
static_cast<uint64_t>(grad_out.shape_[0])*
static_cast<uint64_t>(grad_out.shape_[1]);
if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
AddTakeGrad(grad_in, idx, grad_out);
} else {
AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out);
}
AddTakeGrad(grad_in, idx, grad_out);
} else {
LOG(FATAL) << "wrong req";
}
Expand Down

0 comments on commit fe1c7ab

Please sign in to comment.