From 3f375829f64e117ad22159fe0fc776902ecacddb Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Mon, 16 Mar 2020 16:02:05 -0700 Subject: [PATCH] fixing batch_norm and layer_norm for large tensors (#17805) Co-authored-by: Rohit Kumar Srivastava --- src/operator/nn/batch_norm.cc | 2 +- src/operator/nn/layer_norm.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 60f955399877..1bbdfa63160f 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -374,7 +374,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, : param.axis); CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis; - const int channelCount = dshape[channelAxis]; + const index_t channelCount = dshape[channelAxis]; in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount)); diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index c3ccd0d7a6bc..11178b358c2d 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -51,7 +51,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, CHECK(axis >= 0 && axis < dshape.ndim()) << "Channel axis out of range: axis=" << param.axis; - const int channelCount = dshape[axis]; + const index_t channelCount = dshape[axis]; SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kGamma,