diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 75fa4e20b1de..2a643a266b2b 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -492,24 +492,6 @@ class DropoutOp { #endif // MXNET_USE_CUDNN_DROPOUT }; // class DropoutOp -inline OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, - const Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - const DropoutParam& param = nnvm::get(attrs.parsed); - OpStatePtr state; - MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, { - if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); - } else { - state = OpStatePtr::Create>(param, ctx); - } - return state; - }); - LOG(FATAL) << "should never reach here"; - return OpStatePtr(); // should never reach here -} - template void DropoutCompute(const OpStatePtr& state, const OpContext& ctx, diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index afad6fd5cc80..5ed97b985070 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -26,6 +26,32 @@ #include "./dropout-inl.h" #include "../operator_common.h" +#include "mxnet/op_attr_types.h" + +namespace { + +using namespace mxnet; +using namespace mxnet::op; +OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, + const Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + const auto& param = nnvm::get(attrs.parsed); + OpStatePtr state; + MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, { + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); + } else { + state = OpStatePtr::Create>(param, ctx); + } + return state; + }); + LOG(FATAL) << "should never reach here"; + return OpStatePtr(); // should never reach here +} + +} // anonymous namespace + namespace mxnet { namespace op {