diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 75fa4e20b1de..2a643a266b2b 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -492,24 +492,6 @@ class DropoutOp { #endif // MXNET_USE_CUDNN_DROPOUT }; // class DropoutOp -inline OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, - const Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - const DropoutParam& param = nnvm::get(attrs.parsed); - OpStatePtr state; - MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, { - if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); - } else { - state = OpStatePtr::Create>(param, ctx); - } - return state; - }); - LOG(FATAL) << "should never reach here"; - return OpStatePtr(); // should never reach here -} - template void DropoutCompute(const OpStatePtr& state, const OpContext& ctx, diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index afad6fd5cc80..4b00bf9dac83 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -27,6 +27,29 @@ #include "./dropout-inl.h" #include "../operator_common.h" +namespace { + +OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, + const Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + const DropoutParam& param = nnvm::get(attrs.parsed); + OpStatePtr state; + MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, { + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); + } else { + state = OpStatePtr::Create>(param, ctx); + } + return state; + }); + LOG(FATAL) << "should never reach here"; + return OpStatePtr(); // should never reach here +} + +} // anonymous namespace + + namespace mxnet { namespace op {