This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit
fixing batch_norm and layer_norm for large tensors (#17805) (#18261)
Co-authored-by: Rohit Kumar Srivastava <[email protected]>
ChaiBapchya and Rohit Kumar Srivastava committed May 11, 2020
1 parent 80baab8 commit ceb0f06
Showing 2 changed files with 2 additions and 2 deletions.
src/operator/nn/batch_norm.cc — 1 addition, 1 deletion

```diff
@@ -330,7 +330,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
                                      : param.axis);
   CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis;
 
-  const int channelCount = dshape[channelAxis];
+  const index_t channelCount = dshape[channelAxis];
 
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
```
src/operator/nn/layer_norm.cc — 1 addition, 1 deletion

```diff
@@ -47,7 +47,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   CHECK(axis >= 0 && axis < dshape.ndim())
       << "Channel axis out of range: axis=" << param.axis;
 
-  const int channelCount = dshape[axis];
+  const index_t channelCount = dshape[axis];
 
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
```
