diff --git a/benchmark/opperf/utils/op_registry_utils.py b/benchmark/opperf/utils/op_registry_utils.py
index 65eb6aab2aac..a18eac4e0315 100644
--- a/benchmark/opperf/utils/op_registry_utils.py
+++ b/benchmark/opperf/utils/op_registry_utils.py
@@ -52,7 +52,7 @@ def _select_ops(operator_names, filters=("_contrib", "_"), merge_op_forward_back
     operators_with_backward = []
 
     # Filter out deprecated operators
-    filters += ("normal", "uniform", "BatchNorm_v1", "Flatten", "contrib_CTCLoss", "Pad", "Cast",
+    filters += ("normal", "uniform", "Flatten", "contrib_CTCLoss", "Pad", "Cast",
                 "Pooling_v1", "Concat", "Reshape", "Convolution_v1", "SliceChannel", "Crop", "crop",
                 "onehot_encode", "batch_take")
diff --git a/python/mxnet/contrib/amp/lists/symbol_bf16.py b/python/mxnet/contrib/amp/lists/symbol_bf16.py
index 86edfe6fde8d..d363d3648d70 100644
--- a/python/mxnet/contrib/amp/lists/symbol_bf16.py
+++ b/python/mxnet/contrib/amp/lists/symbol_bf16.py
@@ -57,7 +57,6 @@ FP32_FUNCS = [
     'Deconvolution',
     'RNN',
-    'BatchNorm_v1',
     'BilinearSampler',
     'BlockGrad',
     'Cast',
diff --git a/python/mxnet/contrib/amp/lists/symbol_fp16.py b/python/mxnet/contrib/amp/lists/symbol_fp16.py
index d501a7d6c5b5..a99f596f0018 100644
--- a/python/mxnet/contrib/amp/lists/symbol_fp16.py
+++ b/python/mxnet/contrib/amp/lists/symbol_fp16.py
@@ -32,7 +32,6 @@
 # are dtype neutral (can work in both fp16 and fp32)
 FP16_FP32_FUNCS = [
     'BatchNorm',
-    'BatchNorm_v1',
     'BilinearSampler',
     'BlockGrad',
     'Cast',
diff --git a/src/operator/batch_norm_v1-inl.h b/src/operator/batch_norm_v1-inl.h
deleted file mode 100644
index 1520df93c0e3..000000000000
--- a/src/operator/batch_norm_v1-inl.h
+++ /dev/null
@@ -1,380 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file batch_norm_v1-inl.h
- * \brief
- * \author Bing Xu
-*/
-#ifndef MXNET_OPERATOR_BATCH_NORM_V1_INL_H_
-#define MXNET_OPERATOR_BATCH_NORM_V1_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "./operator_common.h"
-#include "./mshadow_op.h"
-
-namespace mxnet {
-namespace op {
-
-namespace batchnorm_v1 {
-enum BatchNormOpInputs {kData, kGamma, kBeta};
-enum BatchNormOpOutputs {kOut, kMean, kVar};
-enum BatchNormOpAuxiliary {kMovingMean, kMovingVar};
-enum BatchNormBackResource {kTempSpace};
-}  // namespace batchnorm_v1
-
-struct BatchNormV1Param : public dmlc::Parameter<BatchNormV1Param> {
-  float eps;
-  float momentum;
-  bool fix_gamma;
-  bool use_global_stats;
-  bool output_mean_var;
-  DMLC_DECLARE_PARAMETER(BatchNormV1Param) {
-    DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
-    .describe("Epsilon to prevent div 0");
-    DMLC_DECLARE_FIELD(momentum).set_default(0.9f)
-    .describe("Momentum for moving average");
-    DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
-    .describe("Fix gamma while training");
-    DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
-    .describe("Whether use global moving statistics instead of local batch-norm. "
-              "This will force change batch-norm into a scale shift operator.");
-    DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
-    .describe("Output All,normal mean and var");
-  }
-};
-
-template<typename xpu>
-class BatchNormV1Op : public Operator {
- public:
-  explicit BatchNormV1Op(BatchNormV1Param param) {
-    this->param_ = param;
-  }
-
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(aux_states.size(), 2U);
-    if (ctx.is_train) {
-      CHECK_EQ(out_data.size(), 3U);
-      CHECK_EQ(req.size(), 3U);
-    } else {
-      CHECK_GE(out_data.size(), 1U);
-      CHECK_GE(req.size(), 1U);
-      CHECK_EQ(req[batchnorm_v1::kOut], kWriteTo);
-    }
-
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    const real_t scale = static_cast<real_t>(in_data[batchnorm_v1::kData].shape_[1]) /
-                         static_cast<real_t>(in_data[batchnorm_v1::kData].shape_.Size());
-    Tensor<xpu, 4> data;
-    Tensor<xpu, 4> out;
-    if (in_data[batchnorm_v1::kData].ndim() == 2) {
-      Shape<4> dshape = Shape4(in_data[batchnorm_v1::kData].shape_[0],
-                               in_data[batchnorm_v1::kData].shape_[1], 1, 1);
-      data = in_data[batchnorm_v1::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
-      out = out_data[batchnorm_v1::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
-    } else {
-      data = in_data[batchnorm_v1::kData].get<xpu, 4, real_t>(s);
-      out = out_data[batchnorm_v1::kOut].get<xpu, 4, real_t>(s);
-    }
-    Tensor<xpu, 1> slope = in_data[batchnorm_v1::kGamma].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> bias = in_data[batchnorm_v1::kBeta].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> moving_mean = aux_states[batchnorm_v1::kMovingMean].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> moving_var = aux_states[batchnorm_v1::kMovingVar].get<xpu, 1, real_t>(s);
-
-    if (param_.fix_gamma) slope = 1.f;
-
-    // whether use global statistics
-    if (ctx.is_train && !param_.use_global_stats) {
-      Tensor<xpu, 1> mean = out_data[batchnorm_v1::kMean].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 1> var = out_data[batchnorm_v1::kVar].get<xpu, 1, real_t>(s);
-      CHECK(req[batchnorm_v1::kMean] == kNullOp || req[batchnorm_v1::kMean] == kWriteTo);
-      CHECK(req[batchnorm_v1::kVar] == kNullOp || req[batchnorm_v1::kVar] == kWriteTo);
-      // The first three steps must be enforced.
-      mean = scale * sumall_except_dim<1>(data);
-      var = scale * sumall_except_dim<1>(F<mshadow_op::square>(
-          data - broadcast<1>(mean, data.shape_)));
-      Assign(out, req[batchnorm_v1::kOut], broadcast<1>(slope, out.shape_) *
-             (data - broadcast<1>(mean, data.shape_)) /
-             F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_)) +
-             broadcast<1>(bias, out.shape_));
-    } else {
-      Assign(out, req[batchnorm_v1::kOut], broadcast<1>(slope /
-                                          F<mshadow_op::square_root>(moving_var + param_.eps),
-                                          data.shape_) * data +
-             broadcast<1>(bias - (slope * moving_mean) /
-                          F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
-      // Set mean and var tensors to their moving values
-      Tensor<xpu, 1> mean = out_data[batchnorm_v1::kMean].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 1> var = out_data[batchnorm_v1::kVar].get<xpu, 1, real_t>(s);
-      mean = F<mshadow_op::identity>(moving_mean);
-      var = F<mshadow_op::identity>(moving_var);
-    }
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), param_.output_mean_var ? 3U : 1U);
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(out_data.size(), 3U);
-    CHECK_EQ(in_grad.size(), 3U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4> data, grad, grad_in;
-    const real_t scale = static_cast<real_t>(out_grad[batchnorm_v1::kOut].shape_[1]) /
-                         static_cast<real_t>(out_grad[batchnorm_v1::kOut].shape_.Size());
-    if (in_data[batchnorm_v1::kData].ndim() == 2) {
-      Shape<4> dshape = Shape4(out_grad[batchnorm_v1::kOut].shape_[0],
-                               out_grad[batchnorm_v1::kOut].shape_[1], 1, 1);
-      data = in_data[batchnorm_v1::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
-      grad = out_grad[batchnorm_v1::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
-      grad_in = in_grad[batchnorm_v1::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
-    } else {
-      data = in_data[batchnorm_v1::kData].get<xpu, 4, real_t>(s);
-      grad = out_grad[batchnorm_v1::kOut].get<xpu, 4, real_t>(s);
-      grad_in = in_grad[batchnorm_v1::kData].get<xpu, 4, real_t>(s);
-    }
-
-    Tensor<xpu, 1> mean = out_data[batchnorm_v1::kMean].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> var = out_data[batchnorm_v1::kVar].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> slope = in_data[batchnorm_v1::kGamma].get<xpu, 1, real_t>(s);
-    // Tensor<xpu, 1> bias = in_data[kBeta].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> gslope = in_grad[batchnorm_v1::kGamma].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> gbias = in_grad[batchnorm_v1::kBeta].get<xpu, 1, real_t>(s);
-    // update moving avg
-    Tensor<xpu, 1> moving_mean = aux_states[batchnorm_v1::kMovingMean].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> moving_var = aux_states[batchnorm_v1::kMovingVar].get<xpu, 1, real_t>(s);
-
-    if (param_.fix_gamma) slope = 1.f;
-
-    if (ctx.is_train && !param_.use_global_stats) {
-      // get requested temp space
-      Tensor<xpu, 2> workspace = ctx.requested[batchnorm_v1::kTempSpace].get_space<xpu>(
-          mshadow::Shape2(3, mean.shape_[0]), s);
-      Tensor<xpu, 1> gmean = workspace[0];
-      Tensor<xpu, 1> gvar = workspace[1];
-      Tensor<xpu, 1> tmp = workspace[2];
-
-      moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum);
-      moving_var = moving_var * param_.momentum + var * (1 - param_.momentum);
-      // cal
-      gvar = sumall_except_dim<1>((grad * broadcast<1>(slope, data.shape_)) *
-                                  (data - broadcast<1>(mean, data.shape_)) *
-                                  -0.5f *
-                                  F<mshadow_op::power>(broadcast<1>(var + param_.eps, data.shape_),
-                                                       -1.5f));
-      gmean = sumall_except_dim<1>(grad * broadcast<1>(slope, data.shape_));
-      gmean *= -1.0f / F<mshadow_op::square_root>(var + param_.eps);
-      tmp = scale * sumall_except_dim<1>(-2.0f * (data - broadcast<1>(mean, data.shape_)));
-      tmp *= gvar;
-      gmean += tmp;
-      // assign
-      if (!param_.fix_gamma) {
-        Assign(gslope, req[batchnorm_v1::kGamma],
-               sumall_except_dim<1>(
-                   grad * (data - broadcast<1>(mean, data.shape_)) /
-                   F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_))));
-      } else {
-        Assign(gslope, req[batchnorm_v1::kGamma], 0.0f);
-      }
-      Assign(grad_in, req[batchnorm_v1::kData],
-             (grad * broadcast<1>(slope, data.shape_)) *
-             broadcast<1>(1.0f / F<mshadow_op::square_root>(var + param_.eps), data.shape_) +
-             broadcast<1>(gvar, data.shape_) * scale * 2.0f * (data - broadcast<1>(mean,
-                                                                                   data.shape_)) +
-             broadcast<1>(gmean, data.shape_) * scale);
-      Assign(gbias, req[batchnorm_v1::kBeta], sumall_except_dim<1>(grad));
-    } else {
-      // use global statistics with freeze moving mean and var.
-      if (!param_.fix_gamma) {
-        Assign(gslope, req[batchnorm_v1::kGamma],
-               sumall_except_dim<1>(
-                   grad * (data - broadcast<1>(moving_mean, data.shape_)) /
-                   F<mshadow_op::square_root>(broadcast<1>(moving_var + param_.eps, data.shape_))));
-      } else {
-        Assign(gslope, req[batchnorm_v1::kGamma], 0.0f);
-      }
-      Assign(gbias, req[batchnorm_v1::kBeta], sumall_except_dim<1>(grad));
-      Assign(grad_in, req[batchnorm_v1::kData], (grad * broadcast<1>(slope, data.shape_)) *
-             broadcast<1>(
-                 1.0f / F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
-    }
-  }
-
- private:
-  BatchNormV1Param param_;
-};  // class BatchNormV1Op
-
-template<typename xpu>
-Operator *CreateOp(BatchNormV1Param param, int dtype);
-
-
-#if DMLC_USE_CXX11
-class BatchNormV1Prop : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
-    const mxnet::TShape &dshape = in_shape->at(0);
-    if (!mxnet::ndim_is_known(dshape)) return false;
-    in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
-    in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    out_shape->push_back(Shape1(dshape[1]));
-    out_shape->push_back(Shape1(dshape[1]));
-
-    aux_shape->clear();
-    aux_shape->push_back(Shape1(dshape[1]));
-    aux_shape->push_back(Shape1(dshape[1]));
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    using namespace mshadow;
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    // For float16 input type beta, gamma, mean, and average are stored in float32.
-    // For other input types, these parameters have the same type as input
-    // NOTE: This requirement is from cuDNN (v. 4 and 5)
-    int dtype_param = (dtype == kFloat16) ? kFloat32 : dtype;
-    for (size_t i = 1; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype_param;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
-      }
-    }
-    for (size_t i = 0; i < aux_type->size(); ++i) {
-      if ((*aux_type)[i] != -1) {
-        UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
-      }
-    }
-    int n_aux = this->ListAuxiliaryStates().size();
-    aux_type->clear();
-    for (int i = 0; i < n_aux; ++i ) aux_type->push_back(dtype_param);
-    int n_out = this->ListOutputs().size();
-    out_type->clear();
-    out_type->push_back(dtype);
-    for (int i = 1; i < n_out; ++i ) out_type->push_back(dtype_param);
-    return true;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new BatchNormV1Prop();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "BatchNorm_v1";
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-      const std::vector<int> &out_grad,
-      const std::vector<int> &in_data,
-      const std::vector<int> &out_data) const override {
-    return {out_grad[batchnorm_v1::kOut],
-            out_data[batchnorm_v1::kMean],
-            out_data[batchnorm_v1::kVar],
-            in_data[batchnorm_v1::kData],
-            in_data[batchnorm_v1::kGamma]
-           };
-  }
-
-  std::vector<ResourceRequest> BackwardResource(
-      const mxnet::ShapeVector &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
-
-  int NumVisibleOutputs() const override {
-    if (param_.output_mean_var) {
-      return 3;
-    }
-    return 1;
-  }
-
-  int NumOutputs() const override {
-    return 3;
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    return {"data", "gamma", "beta"};
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    return {"output", "mean", "var"};
-  }
-
-  std::vector<std::string> ListAuxiliaryStates() const override {
-    return {"moving_mean", "moving_var"};
-  }
-
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return nullptr;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-                             std::vector<int> *in_type) const override;
-
-  inline const BatchNormV1Param& getParam() const {
-    return param_;
-  }
-
- private:
-  BatchNormV1Param param_;
-};  // class BatchNormV1Prop
-
-#endif  // DMLC_USE_CXX11
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_BATCH_NORM_V1_INL_H_
diff --git a/src/operator/batch_norm_v1.cc b/src/operator/batch_norm_v1.cc
deleted file mode 100644
index c837a5e28b9e..000000000000
--- a/src/operator/batch_norm_v1.cc
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file batch_norm_v1.cc
- * \brief
- * \author Bing Xu
-*/
-
-#include "batch_norm_v1-inl.h"
-#include <nnvm/op_attr_types.h>
-
-namespace mxnet {
-namespace op {
-template<>
-Operator *CreateOp<cpu>(BatchNormV1Param param, int dtype) {
-  return new BatchNormV1Op<cpu>(param);
-}
-
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *BatchNormV1Prop::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-                                            std::vector<int> *in_type) const {
-  mxnet::ShapeVector out_shape, aux_shape;
-  std::vector<int> out_type, aux_type;
-  CHECK(InferType(in_type, &out_type, &aux_type));
-  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
-}
-
-DMLC_REGISTER_PARAMETER(BatchNormV1Param);
-
-MXNET_REGISTER_OP_PROPERTY(BatchNorm_v1, BatchNormV1Prop)
-.describe(R"code(Batch normalization.
-
-This operator is DEPRECATED. Perform BatchNorm on the input.
-
-Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
-well as offset ``beta``.
-
-Assume the input has more than one dimension and we normalize along axis 1.
-We first compute the mean and variance along this axis:
-
-.. math::
-
-  data\_mean[i] = mean(data[:,i,:,...]) \\
-  data\_var[i] = var(data[:,i,:,...])
-
-Then compute the normalized output, which has the same shape as input, as following:
-
-.. math::
-
-  out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
-
-Both *mean* and *var* returns a scalar by treating the input as a vector.
-
-Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
-have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
-``data_var`` as well, which are needed for the backward pass.
-
-Besides the inputs and the outputs, this operator accepts two auxiliary
-states, ``moving_mean`` and ``moving_var``, which are *k*-length
-vectors. They are global statistics for the whole dataset, which are updated
-by::
-
-  moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
-  moving_var = moving_var * momentum + data_var * (1 - momentum)
-
-If ``use_global_stats`` is set to be true, then ``moving_mean`` and
-``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
-the output. It is often used during inference.
-
-Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
-then set ``gamma`` to 1 and its gradient to 0.
-
-There's no sparse support for this operator, and it will exhibit problematic behavior if used with
-sparse tensors.
-
-)code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
-.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
-.add_argument("beta", "NDArray-or-Symbol", "beta array")
-.add_arguments(BatchNormV1Param::__FIELDS__());
-
-NNVM_REGISTER_OP(BatchNorm_v1)
-.set_attr<nnvm::FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose",
-    [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) {
-      if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return;
-      if (index == 3) {
-        var->attrs.dict["__init__"] = "[\"zero\", {}]";
-      } else if (index == 4) {
-        var->attrs.dict["__init__"] = "[\"one\", {}]";
-      }
-    });
-
-}  // namespace op
-}  // namespace mxnet
diff --git a/src/operator/batch_norm_v1.cu b/src/operator/batch_norm_v1.cu
deleted file mode 100644
index 2adbdef3c716..000000000000
--- a/src/operator/batch_norm_v1.cu
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file batch_norm_v1.cu
- * \brief
- * \author Bing Xu
-*/
-
-#include "batch_norm_v1-inl.h"
-
-namespace mxnet {
-namespace op {
-template<>
-Operator *CreateOp<gpu>(BatchNormV1Param param, int dtype) {
-  return new BatchNormV1Op<gpu>(param);
-}
-
-}  // namespace op
-}  // namespace mxnet
-
diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc
index 74c2b546f161..92bd54beffc4 100644
--- a/tests/cpp/operator/batchnorm_test.cc
+++ b/tests/cpp/operator/batchnorm_test.cc
@@ -27,7 +27,6 @@
 #include
 #include
 #include "../../src/operator/nn/batch_norm-inl.h"
-#include "../../src/operator/batch_norm_v1-inl.h"
 #include "../../src/operator/operator_common.h"
 #include "./test_legacy_op.h"
 #include "./test_core_op.h"
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index b6d0011f1a2f..ff6b7953b41f 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -486,13 +486,6 @@ def test_batchnorm_with_type():
                        {'ctx': mx.gpu(0), 'norm_data': (3, 2, 3, 2, 3), 'type_dict': {'norm_data': np.float64}}
                        ]
 
-    # V1, 2D
-    sym = mx.sym.BatchNorm_v1(name='norm', fix_gamma=False)
-    check_consistency(sym, ctx_list_v1_2D)
-    sym = mx.sym.BatchNorm_v1(name='norm', fix_gamma=True)
-    check_consistency(sym, ctx_list_v1_2D)
-
-
     # V2, 2D
     sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_2D)
@@ -526,19 +519,6 @@ def test_batchnorm_versions():
     def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_global_stats):
         ctx_list = []
         sym_list = []
-        # BatchNormV1 cpu
-        if 'batchnorm_v1_cpu' in batchnorm_op_list:
-            ctx_list.append({'ctx': mx.cpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
-            sym_list.append(mx.sym.BatchNorm_v1(fix_gamma=fix_gamma,
-                                                use_global_stats=use_global_stats,
-                                                name='batchnorm'))
-
-        # BatchNormV1 gpu (organic)
-        if 'batchnorm_v1_gpu' in batchnorm_op_list:
-            ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
-            sym_list.append(mx.sym.BatchNorm_v1(fix_gamma=fix_gamma,
-                                                use_global_stats=use_global_stats,
-                                                name='batchnorm'))
 
         # BatchNorm cpu
         if 'batchnorm_cpu' in batchnorm_op_list:
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0e4405379c13..ef37c506436c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1592,27 +1592,15 @@ def check_batchnorm_training(stype):
                        mx.nd.array(beta).tostype(stype)]
         mean_std = [mx.nd.array(rolling_mean).tostype(stype),
                     mx.nd.array(rolling_std).tostype(stype)]
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
-        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
-
         test = mx.symbol.BatchNorm(data, fix_gamma=True)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
-        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
-
         test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
-        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
-
         test = mx.symbol.BatchNorm(data, fix_gamma=False)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
-        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
-
         test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
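
Migration note (not part of the patch): graphs that still build symbols with BatchNorm_v1 need to switch to the regular BatchNorm operator, which the updated tests above already exercise and which accepts the same core arguments (fix_gamma, use_global_stats, momentum, eps). A minimal sketch of the swap using the standard mx.sym / mx.nd APIs; the layer name and tensor shapes below are illustrative only:

import mxnet as mx

# Deprecated form removed by this patch:
#   net = mx.sym.BatchNorm_v1(data, fix_gamma=False, name='bn')
# Replacement with the regular operator, keeping the same parameters:
data = mx.sym.Variable('data')
net = mx.sym.BatchNorm(data, eps=1e-3, momentum=0.9, fix_gamma=False, name='bn')

# Imperative sanity check with explicit gamma/beta and moving statistics
# (a hypothetical 2x3x4x4 batch, normalized along axis 1).
x = mx.nd.random.uniform(shape=(2, 3, 4, 4))
gamma, beta = mx.nd.ones((3,)), mx.nd.zeros((3,))
moving_mean, moving_var = mx.nd.zeros((3,)), mx.nd.ones((3,))
y = mx.nd.BatchNorm(x, gamma, beta, moving_mean, moving_var,
                    eps=1e-3, momentum=0.9, fix_gamma=False)
print(y.shape)  # expected: (2, 3, 4, 4), same shape as the input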