Commit

MXNET_FORCE_ADDTAKEGRAD to disable AddTakeGradLargeBatchCaller (apache#11316)

* MXNET_FORCE_ADDTAKEGRAD to disable AddTakeGradLargeBatchCaller

If MXNET_FORCE_ADDTAKEGRAD is set, EmbeddingOpBackward always uses
AddTakeGrad, regardless of the gradient input and output shapes.

* Read MXNET_FORCE_ADDTAKEGRAD into a static variable
leezu authored and zheng-da committed Jun 28, 2018
1 parent a8a0829 commit 0f5546d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/operator/tensor/indexing_op.h
@@ -598,7 +598,11 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
     uint64_t shape_out_prod =
         static_cast<uint64_t>(grad_out.shape_[0])*
         static_cast<uint64_t>(grad_out.shape_[1]);
-    if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
+
+    static bool default_addtakegrad =
+        dmlc::GetEnv("MXNET_FORCE_ADDTAKEGRAD", false);
+    if (default_addtakegrad || (shape_out_prod < (uint64_t)16384 &&
+                                shape_in_prod < (uint64_t)16384)) {
       AddTakeGrad(grad_in, data, grad_out);
     } else {
       AddTakeGradLargeBatchCaller(ctx, grad_in, data, grad_out);
