fixing batch_norm and layer_norm for large tensors (apache#17805)
Co-authored-by: Rohit Kumar Srivastava <[email protected]>
2 people authored and ChaiBapchya committed Sep 20, 2020
1 parent 0496690 commit 3f37582
Showing 2 changed files with 2 additions and 2 deletions.
src/operator/nn/batch_norm.cc (1 addition, 1 deletion)

@@ -374,7 +374,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
                                           : param.axis);
   CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis;
 
-  const int channelCount = dshape[channelAxis];
+  const index_t channelCount = dshape[channelAxis];
 
   in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
src/operator/nn/layer_norm.cc (1 addition, 1 deletion)

@@ -51,7 +51,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   CHECK(axis >= 0 && axis < dshape.ndim())
       << "Channel axis out of range: axis=" << param.axis;
 
-  const int channelCount = dshape[axis];
+  const index_t channelCount = dshape[axis];
 
   SHAPE_ASSIGN_CHECK(*in_shape,
                      layernorm::kGamma,
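Both hunks apply the same fix: the per-channel dimension was read into a 32-bit int, which silently overflows once an axis of the input shape exceeds 2^31 - 1 elements. MXNet's index_t (from mshadow) widens to a 64-bit integer when large-tensor support is compiled in, so storing the dimension in index_t preserves the true channel count. Below is a minimal standalone sketch of the failure mode, assuming only standard C++; the dimension value and variable names are illustrative, not taken from the commit:

#include <cstdint>
#include <iostream>

int main() {
  // A hypothetical channel dimension just past the 32-bit signed limit.
  const std::int64_t dim = static_cast<std::int64_t>(INT32_MAX) + 1;  // 2147483648

  // Narrowing to a 32-bit int truncates (implementation-defined before
  // C++20, typically wrapping to -2147483648 on two's-complement targets).
  const std::int32_t as_int = static_cast<std::int32_t>(dim);

  // A 64-bit index type, like MXNet's index_t with large-tensor support
  // enabled, holds the value exactly.
  const std::int64_t as_index = dim;

  std::cout << "stored in 32-bit int:   " << as_int << "\n";
  std::cout << "stored in 64-bit index: " << as_index << "\n";
  return 0;
}

Before the change, the gamma/beta shapes built from channelCount could be constructed from such a truncated, negative value; with index_t they match the actual dimension.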
