diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index acc8d5fac6df..3bb6b81a6de2 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -29,6 +29,7 @@ #include #include #include +#include /*! \brief Macros/inlines to assist CLion to parse Cuda files (*.cu, *.cuh) */ #ifdef __JETBRAINS_IDE__ @@ -482,13 +483,10 @@ static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10, * want to populate. */ inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) { -#if CUDNN_MAJOR >= 7 + STATIC_ASSERT_CUDNN_VERSION_GE(7000); int max_algos = 0; CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos)); return max_algos; -#else - return 10; -#endif } /*! @@ -499,13 +497,10 @@ inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) { * want to populate. */ inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) { -#if CUDNN_MAJOR >= 7 + STATIC_ASSERT_CUDNN_VERSION_GE(7000); int max_algos = 0; CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos)); return max_algos; -#else - return 10; -#endif } /*! @@ -516,13 +511,10 @@ inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) { * want to populate. */ inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) { -#if CUDNN_MAJOR >= 7 + STATIC_ASSERT_CUDNN_VERSION_GE(7000); int max_algos = 0; CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos)); return max_algos; -#else - return 10; -#endif } #endif // MXNET_USE_CUDNN diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc index aa03a1104ede..160ba8fb8d63 100644 --- a/src/executor/attach_op_resource_pass.cc +++ b/src/executor/attach_op_resource_pass.cc @@ -82,12 +82,12 @@ void AttachOpResources( requested.push_back(ResourceManager::Get()->Request(ctx, req)); break; } -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 +#if MXNET_USE_CUDNN == 1 case ResourceRequest::kCuDNNDropoutDesc: { requested.push_back(ResourceManager::Get()->Request(ctx, req)); break; } -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 +#endif // MXNET_USE_CUDNN == 1 default: LOG(FATAL) << "resource type " << req.type << " is not yet supported"; } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 477139fd84b8..b3e8bdbbe314 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -263,12 +263,12 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs, requested.push_back(ResourceManager::Get()->Request(ctx, req)); write_vars.push_back(requested.back().var); break; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 +#if MXNET_USE_CUDNN == 1 case ResourceRequest::kCuDNNDropoutDesc: requested.push_back(ResourceManager::Get()->Request(ctx, req)); write_vars.push_back(requested.back().var); break; -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 +#endif // MXNET_USE_CUDNN == 1 default: LOG(FATAL) << "resource type not yet supported"; } diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu index 03734a61316b..fab1433533ce 100644 --- a/src/operator/bilinear_sampler.cu +++ b/src/operator/bilinear_sampler.cu @@ -27,9 +27,9 @@ #include "./bilinear_sampler-inl.h" #include #include "../common/cuda_utils.h" -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 #include "./cudnn_bilinear_sampler-inl.h" -#endif // MXNET_USE_CUDNN && CUDNN_MAJOR +#endif // MXNET_USE_CUDNN namespace mshadow { namespace cuda { @@ -228,7 +228,7 @@ namespace op { template<> Operator* CreateOp(BilinearSamplerParam param, int dtype) { Operator *op = NULL; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { if (param.cudnn_off.has_value() && param.cudnn_off.value()) { op = new BilinearSamplerOp(param); @@ -240,7 +240,7 @@ Operator* CreateOp(BilinearSamplerParam param, int dtype) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new BilinearSamplerOp(param); }) -#endif // MXNET_USE_CUDNN && CUDNN_MAJOR +#endif // MXNET_USE_CUDNN return op; } diff --git a/src/operator/cudnn_bilinear_sampler-inl.h b/src/operator/cudnn_bilinear_sampler-inl.h index c2171e6651a6..d72257682edd 100644 --- a/src/operator/cudnn_bilinear_sampler-inl.h +++ b/src/operator/cudnn_bilinear_sampler-inl.h @@ -31,7 +31,8 @@ #include "./bilinear_sampler-inl.h" namespace mxnet { namespace op { -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 +STATIC_ASSERT_CUDNN_VERSION_GE(5000); template class CuDNNBilinearSamplerOp : public Operator { public: @@ -132,9 +133,7 @@ class CuDNNBilinearSamplerOp : public Operator { const std::vector &in_data, const std::vector &out_data) { using namespace mshadow; - #if CUDNN_MAJOR >= 5 format_ = CUDNN_TENSOR_NCHW; - #endif CHECK_EQ(in_data.size(), 2U); CHECK_EQ(out_data.size(), 2U); if (!init_cudnn_) { @@ -174,12 +173,10 @@ class CuDNNBilinearSamplerOp : public Operator { cudnnTensorDescriptor_t in_desc_; cudnnTensorDescriptor_t out_desc_; cudnnSamplerType_t sampler_; - #if CUDNN_MAJOR >= 5 cudnnTensorFormat_t format_; - #endif BilinearSamplerParam param_; }; -#endif // __CUDACC__ && CUDNN +#endif // __CUDACC__ && MXNET_USE_CUDNN } // namespace op } // namespace mxnet diff --git a/src/operator/cudnn_spatial_transformer-inl.h b/src/operator/cudnn_spatial_transformer-inl.h index 1d7242a83c74..2e069e515563 100644 --- a/src/operator/cudnn_spatial_transformer-inl.h +++ b/src/operator/cudnn_spatial_transformer-inl.h @@ -31,7 +31,8 @@ #include "./spatial_transformer-inl.h" namespace mxnet { namespace op { -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 +STATIC_ASSERT_CUDNN_VERSION_GE(5000); template class CuDNNSpatialTransformerOp : public Operator { public: @@ -145,9 +146,7 @@ class CuDNNSpatialTransformerOp : public Operator { const std::vector &in_data, const std::vector &out_data) { using namespace mshadow; - #if CUDNN_MAJOR >= 5 format_ = CUDNN_TENSOR_NCHW; - #endif CHECK_EQ(in_data.size(), 2U); CHECK_EQ(out_data.size(), 3U); if (!init_cudnn_) { @@ -189,12 +188,10 @@ class CuDNNSpatialTransformerOp : public Operator { cudnnTensorDescriptor_t in_desc_; cudnnTensorDescriptor_t out_desc_; cudnnSamplerType_t sampler_; - #if CUDNN_MAJOR >= 5 cudnnTensorFormat_t format_; - #endif SpatialTransformerParam param_; }; -#endif // __CUDACC__ && CUDNN +#endif // __CUDACC__ && MXNET_USE_CUDNN } // namespace op } // namespace mxnet diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 9fb44e8fae81..be9309c8bfb1 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -35,7 +35,7 @@ #define IS_TRAINING_FLAG 16 #define USE_GLOBAL_STATS_FLAG 32 -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 #include "./cudnn/cudnn_batch_norm-inl.h" #endif @@ -641,7 +641,7 @@ void BatchNormBackwardImpl(mshadow::Stream *stream, MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu); } -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4 +#if MXNET_USE_CUDNN == 1 template static CuDNNBatchNormOp &GetCuDNNOp(const BatchNormParam& param) { #if DMLC_CXX11_THREAD_LOCAL @@ -667,7 +667,7 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape shape = inputs[0].shape_; param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 if (!param.use_global_stats && !param.cudnn_off && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { @@ -696,7 +696,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape shape = inputs[0].shape_; param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 if (!param.use_global_stats && !param.cudnn_off && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index 010be8a208fb..053c3fb6c748 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -94,39 +94,8 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param = nnvm::get(attrs.parsed); int dtype = inputs[conv::kData].type_flag_; -#if CUDNN_MAJOR < 5 - if (param.layout.value() != kNCW && - param.layout.value() != kNCHW && - param.layout.value() != kNCDHW) { - // Need CuDNN > 5.0 for layout support. use MXNet implementation - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - ConvolutionOp op; - op.Init(param); - op.Forward(ctx, inputs, req, outputs); - }) - return; - } -#endif - -#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7 - if (param.num_filter == param.num_group && - param.layout.value() == mshadow::kNCHW && - param.num_filter == inputs[conv::kData].shape_[1] && - param.kernel.ndim() == 2 && - param.dilate == mshadow::Shape2(1, 1) && - dtype == mshadow::kFloat32) { - mxnet::ShapeVector in_shape(inputs.size()); - mxnet::ShapeVector out_shape(1, outputs[0].shape_); - for (size_t i = 0; i < in_shape.size(); i++) - in_shape[i] = inputs[i].shape_; - DepthwiseConvolutionOp op; - op.Init(param, in_shape, out_shape); - op.Forward(ctx, inputs, req, outputs); - return; - } -#endif - #if MXNET_USE_CUDNN == 1 + STATIC_ASSERT_CUDNN_VERSION_GE(7000); // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype; @@ -154,6 +123,22 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs, } }) #else + if (param.num_filter == param.num_group && + param.layout.value() == mshadow::kNCHW && + param.num_filter == inputs[conv::kData].shape_[1] && + param.kernel.ndim() == 2 && + param.dilate == mshadow::Shape2(1, 1) && + dtype == mshadow::kFloat32) { + mxnet::ShapeVector in_shape(inputs.size()); + mxnet::ShapeVector out_shape(1, outputs[0].shape_); + for (size_t i = 0; i < in_shape.size(); i++) + in_shape[i] = inputs[i].shape_; + DepthwiseConvolutionOp op; + op.Init(param, in_shape, out_shape); + op.Forward(ctx, inputs, req, outputs); + return; + } + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { ConvolutionOp op; op.Init(param); @@ -174,39 +159,8 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs, const std::vector &in_grad = outputs; int dtype = out_grad.type_flag_; -#if CUDNN_MAJOR < 5 - if (param.layout.value() != kNCW && - param.layout.value() != kNCHW && - param.layout.value() != kNCDHW) { - // Need CuDNN > 5.0 for layout support. use MXNet implementation - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - ConvolutionOp op; - op.Init(param); - op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); - }) - return; - } -#endif -#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7 - if (param.num_filter == param.num_group && - param.layout.value() == mshadow::kNCHW && - param.num_filter == in_data[conv::kData].shape_[1] && - param.kernel.ndim() == 2 && - param.dilate == mshadow::Shape2(1, 1) && - dtype == mshadow::kFloat32) { - // The first element stores out grad. - mxnet::ShapeVector in_shape(in_data.size()); - mxnet::ShapeVector out_shape(1, out_grad.shape_); - for (size_t i = 0; i < in_shape.size(); i++) - in_shape[i] = in_data[i].shape_; - DepthwiseConvolutionOp op; - op.Init(param, in_shape, out_shape); - op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); - return; - } -#endif - #if MXNET_USE_CUDNN == 1 + STATIC_ASSERT_CUDNN_VERSION_GE(7000); // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype; @@ -234,6 +188,23 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs, } }) #else + if (param.num_filter == param.num_group && + param.layout.value() == mshadow::kNCHW && + param.num_filter == in_data[conv::kData].shape_[1] && + param.kernel.ndim() == 2 && + param.dilate == mshadow::Shape2(1, 1) && + dtype == mshadow::kFloat32) { + // The first element stores out grad. + mxnet::ShapeVector in_shape(in_data.size()); + mxnet::ShapeVector out_shape(1, out_grad.shape_); + for (size_t i = 0; i < in_shape.size(); i++) + in_shape[i] = in_data[i].shape_; + DepthwiseConvolutionOp op; + op.Init(param, in_shape, out_shape); + op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); + return; + } + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { ConvolutionOp op; op.Init(param); diff --git a/src/operator/nn/cudnn/cudnn_activation-inl.h b/src/operator/nn/cudnn/cudnn_activation-inl.h index 2c1f442808c1..186274b2f1e1 100644 --- a/src/operator/nn/cudnn/cudnn_activation-inl.h +++ b/src/operator/nn/cudnn/cudnn_activation-inl.h @@ -29,18 +29,19 @@ #include #include #include "../activation-inl.h" +#include "../../../common/cuda_utils.h" namespace mxnet { namespace op { template class CuDNNActivationOp { + STATIC_ASSERT_CUDNN_VERSION_GE(5000); + public: CuDNNActivationOp() { dtype_ = mshadow::DataType::kCudnnFlag; - #if CUDNN_MAJOR >= 5 nan_prop_ = CUDNN_NOT_PROPAGATE_NAN; CUDNN_CALL(cudnnCreateActivationDescriptor(&desc_)); - #endif CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc_)); } @@ -60,16 +61,12 @@ class CuDNNActivationOp { LOG(FATAL) << "Not implmented"; break; } - #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnSetActivationDescriptor(desc_, mode_, nan_prop_, relu_ceil_)); - #endif } ~CuDNNActivationOp() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc_)); - #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnDestroyActivationDescriptor(desc_)); - #endif } void Forward(const OpContext &ctx, const TBlob &in_data, @@ -109,16 +106,6 @@ class CuDNNActivationOp { data.shape_[1], data.shape_[2], data.shape_[3])); - #if CUDNN_MAJOR <= 4 - CUDNN_CALL(cudnnActivationForward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_)); - #elif CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnActivationForward(s->dnn_handle_, desc_, &alpha, @@ -127,7 +114,6 @@ class CuDNNActivationOp { &beta, shape_desc_, out.dptr_)); - #endif } // backward computation for cudnn activation operator. Note that for relu @@ -177,20 +163,6 @@ class CuDNNActivationOp { data.shape_[1], data.shape_[2], data.shape_[3])); - #if CUDNN_MAJOR <= 4 - CUDNN_CALL(cudnnActivationBackward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - input_grad.dptr_)); - #elif CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnActivationBackward(s->dnn_handle_, desc_, &alpha, @@ -203,7 +175,6 @@ class CuDNNActivationOp { &beta, shape_desc_, input_grad.dptr_)); - #endif } private: @@ -211,11 +182,9 @@ class CuDNNActivationOp { cudnnActivationMode_t mode_; cudnnTensorDescriptor_t shape_desc_; ActivationParam param_; -#if CUDNN_MAJOR >= 5 cudnnActivationDescriptor_t desc_; cudnnNanPropagation_t nan_prop_; double relu_ceil_; -#endif }; // class CuDNNActivationOp } // namespace op } // namespace mxnet diff --git a/src/operator/nn/cudnn/cudnn_algoreg-inl.h b/src/operator/nn/cudnn/cudnn_algoreg-inl.h index 3f2d24c5bf7e..f7e01e214719 100644 --- a/src/operator/nn/cudnn/cudnn_algoreg-inl.h +++ b/src/operator/nn/cudnn/cudnn_algoreg-inl.h @@ -44,6 +44,8 @@ namespace op { */ template class CuDNNAlgo { + STATIC_ASSERT_CUDNN_VERSION_GE(7000); + public: CuDNNAlgo() : algo_number_(static_cast(0)), @@ -54,11 +56,9 @@ class CuDNNAlgo { } CuDNNAlgoType AlgoNumber() const { return algo_number_; } bool IsTensorCoreAlgo() const { return is_tensor_core_algo_; } - #if CUDNN_MAJOR >= 7 cudnnMathType_t MathType() { return IsTensorCoreAlgo() ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; } - #endif private: CuDNNAlgoType algo_number_; bool is_tensor_core_algo_; diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 820f8504d74c..3fc91196708c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -34,7 +34,7 @@ namespace mxnet { namespace op { -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4 +#if MXNET_USE_CUDNN == 1 namespace cudnnbatchnorm { enum CuDNNBatchNormOpInputs {kData, kGamma, kBeta}; enum CuDNNBatchNormOpOutputs {kOut, kMean, kInvVar}; @@ -44,6 +44,8 @@ enum CuDNNBatchNormOpAuxiliary {kMovingMean, kMovingInvVar}; #if defined(__CUDACC__) template class CuDNNBatchNormOp { + STATIC_ASSERT_CUDNN_VERSION_GE(5000); + public: CuDNNBatchNormOp() { using namespace mshadow; @@ -182,7 +184,6 @@ class CuDNNBatchNormOp { const bool global_stats = !ctx.is_train || param_.use_global_stats; -#if CUDNN_VERSION >= 4007 #if CUDNN_VERSION >= 7002 auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; #else @@ -229,45 +230,6 @@ class CuDNNBatchNormOp { global_stats ? nullptr : save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) -#else // CUDNN_VERSION < 4007 - MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, { - Tensor gamma = - in_gamma.get_with_shape(Shape1(shape_[1]), s); - Tensor dbeta = - in_grad[cudnnbatchnorm::kBeta].get_with_shape(Shape1(shape_[1]), s); - Tensor dgamma = - in_grad[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); - Tensor save_mean = - out_mean.get_with_shape(Shape1(shape_[1]), s); - Tensor save_inv_var = - out_var.get_with_shape(Shape1(shape_[1]), s); - - typename DataType::ScaleType a = 1.0f; - typename DataType::ScaleType b = 0.0f; - typename DataType::ScaleType b_add = 1.0f; - CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - - if (param_.fix_gamma) gamma = 1.f; - CUDNN_CALL(cudnnBatchNormalizationBackward(s->dnn_handle_, - CUDNN_BATCHNORM_SPATIAL, - &a, - &b, - io_desc_, - x.dptr_, - io_desc_, - dy.dptr_, - io_desc_, - dx.dptr_, - mean_desc_, - gamma.dptr_, - dgamma.dptr_, - dbeta.dptr_, - param_.eps, - global_stats ? nullptr : save_mean.dptr_, - global_stats ? nullptr : save_inv_var.dptr_)); - if (param_.fix_gamma) dgamma = 0.f; - }) -#endif } private: @@ -303,7 +265,7 @@ class CuDNNBatchNormOp { }; #endif // defined(__CUDACC__) -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4 +#endif // MXNET_USE_CUDNN == 1 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_ diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc index cb35ce170e8e..d691b785a6e6 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm.cc +++ b/src/operator/nn/cudnn/cudnn_batch_norm.cc @@ -30,7 +30,7 @@ namespace mxnet { namespace op { -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4 +#if MXNET_USE_CUDNN == 1 static bool BatchNormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) { @@ -114,7 +114,7 @@ NNVM_REGISTER_OP(_backward_CuDNNBatchNorm) .set_attr_parser(ParamParser) .set_attr("FCompute", BatchNormGradCompute_CPU); -#endif // CUDNN_MAJOR >= 4 +#endif // MXNET_USE_CUDNN } // namespace op } // namespace mxnet diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cu b/src/operator/nn/cudnn/cudnn_batch_norm.cu deleted file mode 100644 index e07cd1e6c8f6..000000000000 --- a/src/operator/nn/cudnn/cudnn_batch_norm.cu +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015 by Contributors - * \file cudnn_batch_norm.cu - * \brief - * \author Junyuan Xie, Da Zheng -*/ - -#include "./cudnn_batch_norm-inl.h" -#include - -namespace mxnet { -namespace op { -#if CUDNN_MAJOR == 4 - -template -static CuDNNBatchNormOp &GetCuDNNOp(const BatchNormParam& param) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local CuDNNBatchNormOp op; -#else - static MX_THREAD_LOCAL CuDNNBatchNormOp op; -#endif - op.Init(param); - return op; -} - -static void BatchNormCompute_CuDNNv4(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { -#if CUDNN_MAJOR >= 5 - LOG(FATAL) << "CuDNNBatchNorm is merged into BatchNorm for cudnn version above v5." - "Use the later instead."; -#else - const BatchNormParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(inputs.size(), 5U); - std::vector in_data(inputs.begin(), inputs.begin() + 3); - std::vector aux_states(inputs.begin() + 3, inputs.end()); - GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); -#endif -} - -static void BatchNormGradCompute_CuDNNv4(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { -#if CUDNN_MAJOR >= 5 - LOG(FATAL) << "CuDNNBatchNorm is merged into BatchNorm for cudnn version above v5." - "Use the later instead."; -#else - CHECK_EQ(inputs.size(), 11U); - const BatchNormParam& param = nnvm::get(attrs.parsed); - std::vector out_grad(1, inputs[0]); - std::vector in_data(inputs.begin() + 3, inputs.begin() + 6); - std::vector aux_states(inputs.begin() + 6, inputs.begin() + 8); - std::vector out_data(inputs.begin() + 8, inputs.end()); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); - GetCuDNNOp(param).Backward(ctx, out_grad, in_data, out_data, - req, in_grad, aux_states); -#endif -} - -NNVM_REGISTER_OP(CuDNNBatchNorm) -.set_attr("FCompute", BatchNormCompute_CuDNNv4); - -NNVM_REGISTER_OP(_backward_CuDNNBatchNorm) -.set_attr("FCompute", BatchNormGradCompute_CuDNNv4); - -#endif // CUDNN_MAJOR == 4 -} // namespace op -} // namespace mxnet - diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 679e0cd1057b..d35e41701918 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -44,6 +44,8 @@ namespace op { */ template class CuDNNConvolutionOp { + STATIC_ASSERT_CUDNN_VERSION_GE(7000); + public: CuDNNConvolutionOp() { CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); @@ -75,7 +77,6 @@ class CuDNNConvolutionOp { // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. cudnn_tensor_core_ = DataType::kFlag == kFloat16 && GetEnvAllowTensorCore(); -#if CUDNN_MAJOR >= 5 auto effective_layout = param_.layout.value(); switch (effective_layout) { // 1D convolutions will be executed as 2D convolutions with a height of 1. @@ -88,14 +89,9 @@ class CuDNNConvolutionOp { MSHADOW_LAYOUT_SWITCH(effective_layout, Layout, { format_ = LayoutType::kCudnnFlag; }); -#else - CHECK(param_.layout.value() == kNCW || - param_.layout.value() == kNCHW || - param_.layout.value() == kNCDHW) << "Need CuDNN > 5.0 for layout support"; -#endif // Double check to make sure this class supports the operation if (!Supports(param, forward_compute_type, backward_compute_type, rctx.ctx.dev_id)) - LOG(FATAL) << "Need CuDNN >= 6.0 for dilated convolution."; + LOG(FATAL) << "Convolution parameters not supported by cuDNN implementation."; InitDescriptors(in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); @@ -141,7 +137,6 @@ class CuDNNConvolutionOp { DType *wmat_ptr = GetNdPtr(in_data[conv::kWeight], param_.kernel.ndim() + 2, s); DType *out_ptr = GetNdPtr(out_data[conv::kOut], param_.kernel.ndim() + 2, s); - #if CUDNN_MAJOR >= 7 typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; typename DataType::ScaleType beta_add = 1.0f; @@ -169,48 +164,6 @@ class CuDNNConvolutionOp { out_desc_, out_ptr)); } - #else - for (uint32_t g = 0; g < param_.num_group; ++g) { - typename DataType::ScaleType alpha = 1.0f; - typename DataType::ScaleType beta = 0.0f; - typename DataType::ScaleType beta_add = 1.0f; - CUDNN_CALL(cudnnConvolutionForward(s->dnn_handle_, - &alpha, - in_desc_, - data_ptr + data_offset_ * g, - filter_desc_, - wmat_ptr + weight_offset_ * g, - forward_conv_desc_, - forward_algo_.AlgoNumber(), - workspace.dptr_, - workspace_size, - req[conv::kOut] == kAddTo? &beta_add : &beta, - out_desc_, - out_ptr + out_offset_ * g)); - if (!param_.no_bias) { - Tensor bias = in_data[conv::kBias].get(s); - #if CUDNN_MAJOR >= 4 - CUDNN_CALL(cudnnAddTensor(s->dnn_handle_, - &alpha, - bias_desc_, - bias.dptr_ + bias_offset_ * g, - &beta_add, - out_desc_, - out_ptr + out_offset_ * g)); - #endif - #if CUDNN_MAJOR == 3 - CUDNN_CALL(cudnnAddTensor(s->dnn_handle_, - CUDNN_ADD_SAME_C, - &alpha, - bias_desc_, - bias.dptr_ + bias_offset_ * g, - &beta_add, - out_desc_, - out_ptr + out_offset_ * g)); - #endif - } - } - #endif // CUDNN_MAJOR >= 7 } void Backward(const OpContext &ctx, @@ -256,7 +209,6 @@ class CuDNNConvolutionOp { CHECK_LE(back_workspace_byte_dgrad_, workspace_size); CHECK_LE(back_workspace_byte_wgrad_, workspace_size); } - #if CUDNN_MAJOR >= 7 typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; typename DataType::ScaleType beta_add = 1.0f; @@ -301,85 +253,6 @@ class CuDNNConvolutionOp { in_desc_, gdata_ptr)); } - #else - for (uint32_t g = 0; g < param_.num_group; ++g) { - typename DataType::ScaleType alpha = 1.0f; - typename DataType::ScaleType beta = 0.0f; - typename DataType::ScaleType beta_add = 1.0f; - if (!param_.no_bias && (req[conv::kBias] != kNullOp)) { - Tensor gbias = in_grad[conv::kBias].get(s); - CUDNN_CALL(cudnnConvolutionBackwardBias(s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr + out_offset_ * g, - req[conv::kBias] == kAddTo ? &beta_add : &beta, - bias_desc_, - gbias.dptr_ + bias_offset_ * g)); - } - if (req[conv::kWeight] != kNullOp) { - #if CUDNN_MAJOR <= 4 - CUDNN_CALL(cudnnConvolutionBackwardFilter_v3(s->dnn_handle_, - &alpha, - in_desc_, - data_ptr + data_offset_ * g, - out_desc_, - grad_ptr + out_offset_ * g, - back_conv_desc_w_, - back_algo_w_.AlgoNumber(), - workspace.dptr_, - workspace_size, - req[conv::kWeight] == kAddTo? &beta_add : &beta, - filter_desc_, - gwmat_ptr + weight_offset_ * g)); - #elif CUDNN_MAJOR >= 5 - CUDNN_CALL(cudnnConvolutionBackwardFilter(s->dnn_handle_, - &alpha, - in_desc_, - data_ptr + data_offset_ * g, - out_desc_, - grad_ptr + out_offset_ * g, - back_conv_desc_w_, - back_algo_w_.AlgoNumber(), - workspace.dptr_, - workspace_size, - req[conv::kWeight] == kAddTo? &beta_add : &beta, - filter_desc_, - gwmat_ptr + weight_offset_ * g)); - #endif - } - if (req[conv::kData] != kNullOp) { - #if CUDNN_MAJOR <= 4 - CUDNN_CALL(cudnnConvolutionBackwardData_v3(s->dnn_handle_, - &alpha, - filter_desc_, - wmat_ptr + weight_offset_ * g, - out_desc_, - grad_ptr + out_offset_ * g, - back_conv_desc_, - back_algo_.AlgoNumber(), - workspace.dptr_, - workspace_size, - req[conv::kData] == kAddTo? &beta_add : &beta, - in_desc_, - gdata_ptr + data_offset_ * g)); - #elif CUDNN_MAJOR >= 5 - CUDNN_CALL(cudnnConvolutionBackwardData(s->dnn_handle_, - &alpha, - filter_desc_, - wmat_ptr + weight_offset_ * g, - out_desc_, - grad_ptr + out_offset_ * g, - back_conv_desc_, - back_algo_.AlgoNumber(), - workspace.dptr_, - workspace_size, - req[conv::kData] == kAddTo? &beta_add : &beta, - in_desc_, - gdata_ptr + data_offset_ * g)); - #endif - } - } - #endif // CUDNN_MAJOR >= 7 } /*! @@ -407,14 +280,7 @@ class CuDNNConvolutionOp { return false; } - // The factor by which the effective filter size grows based on dilation. - auto filterDilationFactor = param.dilate.Size(); - - // The v6 kernels that backprop a dilated convolution don't handle fp16. - // Dilation support across all architectures only available after v6.0.20. - return filterDilationFactor == 1 || - filterDilationFactor > 1 && (CUDNN_VERSION > 6020) && - (backward_compute_type != kFloat16); + return true; } private: @@ -443,17 +309,7 @@ class CuDNNConvolutionOp { mxnet::TShape wshape = in_shape[conv::kWeight]; mxnet::TShape oshape = out_shape[conv::kOut]; mxnet::TShape dstride, ostride; -#if CUDNN_MAJOR <= 6 - wshape[0] /= param_.num_group; -#endif -#if CUDNN_MAJOR <= 5 - // As of cuDNN_v6, the unsuffixed version of cudnnSetConvolution2dDescriptor() - // takes an additional 'computeType' parameter to set the precision of the - // convolution calculation. Supply this method signature for cuDNN versions < 6. -#define cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m, ct) \ - cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m) -#endif if (param_.kernel.ndim() == 1 || param_.kernel.ndim() == 2) { // 1d or 2d conv auto pad = param_.kernel.ndim() == 2 ? @@ -489,13 +345,6 @@ class CuDNNConvolutionOp { dilate[1], CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); -#if CUDNN_MAJOR < 5 - // As of cuDNN_v5, cudnnSetFilter4dDescriptor() takes a format parameter. - // Supply this method signature for cuDNN versions < 5. -#define cudnnSetFilter4dDescriptor(fdesc, dt, f, w0, w1, w2, w3) \ - cudnnSetFilter4dDescriptor(fdesc, dt, w0, w1, w2, w3) - CHECK_EQ(format_, CUDNN_TENSOR_NCHW) << "CuDNN V4 and earlier only supports NCHW layout"; -#endif if (param_.kernel.ndim() == 2) { wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); dstride = ConvertLayout(Strides<4>(dshape), param_.layout.value(), kNCHW); @@ -536,7 +385,6 @@ class CuDNNConvolutionOp { #endif } else if (param_.kernel.ndim() == 3) { // 3d conv - #if CUDNN_MAJOR >= 5 CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; std::vector wshape_buffer(wshape.ndim()); CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_, @@ -544,9 +392,6 @@ class CuDNNConvolutionOp { CUDNN_TENSOR_NCHW, static_cast(wshape.ndim()), CastTShapeToIntPtr(wshape, &wshape_buffer))); - #else - LOG(FATAL) << "Only support CUDNN V5 for 3D convolution"; - #endif CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, 3, param_pad_.data(), @@ -577,29 +422,19 @@ class CuDNNConvolutionOp { oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW); } // Set "allow tensor core" flag in convolution descriptors, if available. - #if CUDNN_MAJOR >= 7 - cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH - : CUDNN_DEFAULT_MATH; - #if CUDNN_VERSION >= 7200 - if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && - (DataType::kFlag != kFloat16)) - math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; - #endif - CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type)); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type)); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type)); - CUDNN_CALL(cudnnSetConvolutionGroupCount(forward_conv_desc_, param_.num_group)); - CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_, param_.num_group)); - CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group)); - #endif - - #if CUDNN_MAJOR <= 6 - dshape[1] /= param_.num_group; - oshape[1] /= param_.num_group; - #endif - weight_offset_ = wshape.Size(); - data_offset_ = dstride[1] * dshape[1]; - out_offset_ = ostride[1] * oshape[1]; + cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH + : CUDNN_DEFAULT_MATH; +#if CUDNN_VERSION >= 7200 + if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && + (DataType::kFlag != kFloat16)) + math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; +#endif + CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type)); + CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type)); + CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type)); + CUDNN_CALL(cudnnSetConvolutionGroupCount(forward_conv_desc_, param_.num_group)); + CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_, param_.num_group)); + CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group)); std::vector dshape_buffer(dshape.ndim()); nnvm::ShapeTypeCast(dshape.begin(), dshape.end(), dshape_buffer.data()); @@ -624,18 +459,10 @@ class CuDNNConvolutionOp { if (!param_.no_bias) { mxnet::TShape bias = in_shape[conv::kBias]; - #if CUDNN_MAJOR >= 7 - bias_offset_ = bias[0]; std::vector bias_shape = {1, static_cast(bias[0]), 1, 1}; - #else - bias_offset_ = bias[0] / param_.num_group; - std::vector bias_shape = {1, - static_cast(bias[0] / param_.num_group), - 1, 1}; - #endif - std::vector bias_stride = {static_cast(bias_offset_), 1, 1, 1}; + std::vector bias_stride = {static_cast(bias[0]), 1, 1, 1}; if (param_.kernel.ndim() == 3) { bias_shape.push_back(1); bias_stride.push_back(1); @@ -660,12 +487,9 @@ class CuDNNConvolutionOp { mshadow::Stream *s = rctx.get_stream(); CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); -#if CUDNN_MAJOR >= 7 - // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire - // story: the notion of whether the algo ran in Tensor Core mode is not known. - // Since we want to report the Tensor Core mode in the verbose output, we switch - // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches - // that of *Find*(), we can unify the find-vs-get logic by using function pointers. + + // Since the function signature of *Get*_v7() matches that of *Find*(), + // we can unify the find-vs-get logic by using function pointers. // Forward Algorithm Find/Get() v7 std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); @@ -727,130 +551,6 @@ class CuDNNConvolutionOp { AlgoFinalSelect(bwd_data_results, "backprop-to-data", workspace_byte, bwd, exclude_dgrad_algo_); -#else - // CUDNN_MAJOR < 7 - const int kMaxAlgos = 10; - int nalgo = kMaxAlgos; - int i = 0; - size_t min_memory_needs = 0; - // Forward Algorithm Find/Get, v6 and earlier - if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { - // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is - // supported. Hard-coded this since the algo find() or get() throws an FPE. - fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); - } else if (!param_.cudnn_tune.value()) { - cudnnConvolutionFwdAlgo_t fastest_fwd_algo; - CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_fwd_algo)); - fwd->Set(fastest_fwd_algo, false); - } else { - cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - kMaxAlgos, - &nalgo, - fwd_algo)); - i = 0; - while (i < nalgo - && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() != conv::kFastest - && fwd_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = - (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, fwd_algo[i].memory); - } - if (i == nalgo) { - LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, "forward"); - } else { - fwd->Set(fwd_algo[i].algo, false); - } - } - // Backprop-to-Filter Algorithm Find/Get, v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_filt_algo)); - flt->Set(fastest_bwd_filt_algo, false); - } else { - cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo)); - i = 0; - while (i < nalgo - && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() != conv::kFastest - && bwd_filter_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_filter_algo[i].memory : - std::min(min_memory_needs, bwd_filter_algo[i].memory); - } - if (i == nalgo) { - LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, "backward filter"); - } else { - flt->Set(bwd_filter_algo[i].algo, false); - } - } - // Backprop-to-Data Algorithm Get(), v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_data_algo)); - bwd->Set(fastest_bwd_data_algo, false); - } else { - cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - kMaxAlgos, - &nalgo, - bwd_data_algo)); - i = 0; - while (i < nalgo - && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() != conv::kFastest - && bwd_data_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_data_algo[i].memory : - std::min(min_memory_needs, bwd_data_algo[i].memory); - } - if (i == nalgo) { - LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, "backward data"); - } else { - bwd->Set(bwd_data_algo[i].algo, false); - } - } -#endif // CUDNN_MAJOR < 7 // Fix for issue #11241 int cudnn_find_issue_max_features = 64 * 1024; @@ -911,11 +611,9 @@ class CuDNNConvolutionOp { // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, // we must change the descriptor to preclude Tensor Core. Simplest is to // once again set the mathType in all cases. - #if CUDNN_MAJOR >= 7 CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, forward_algo_.MathType())); CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, back_algo_.MathType())); CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); - #endif } // Look over the results from *Find*() or *Get*() and pick the fastest algo given possible @@ -931,13 +629,9 @@ class CuDNNConvolutionOp { const auto &result = perf_results[i]; bool algo_exclusion = static_cast(result.algo) == algo_exclude; bool algo_is_tensor_core = false; - #if CUDNN_MAJOR >= 7 - algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; - #endif + algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; if (result.status == CUDNN_STATUS_SUCCESS && - #if CUDNN_MAJOR >= 7 (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && - #endif (param_.cudnn_tune.value() == conv::kLimited || result.memory <= workspace_byte) && !algo_exclusion) { algo->Set(result.algo, algo_is_tensor_core); @@ -1091,10 +785,6 @@ class CuDNNConvolutionOp { size_t back_workspace_byte_dgrad_; // Temp workspace size in bytes needed for Backward() wgrad (weight gradient) operation. size_t back_workspace_byte_wgrad_; - size_t data_offset_; - size_t out_offset_; - size_t weight_offset_; - size_t bias_offset_; cudnnDataType_t dtype_; cudnnTensorDescriptor_t in_desc_; cudnnTensorDescriptor_t out_desc_; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index adb6caf1c028..ec7eec32b5b8 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -41,6 +41,8 @@ namespace op { template class CuDNNDeconvolutionOp { + STATIC_ASSERT_CUDNN_VERSION_GE(7000); + public: CuDNNDeconvolutionOp() { CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); @@ -71,7 +73,6 @@ class CuDNNDeconvolutionOp { // TensorCore algos only allowed on fp16-I/O deconvolutions if permitted by the global policy. cudnn_tensor_core_ = DataType::kFlag == kFloat16 && GetEnvAllowTensorCore(); -#if CUDNN_MAJOR >= 5 auto effective_layout = param_.layout.value(); switch (effective_layout) { // 1D convolutions will be executed as 2D convolutions with a height of 1. @@ -84,14 +85,9 @@ class CuDNNDeconvolutionOp { MSHADOW_LAYOUT_SWITCH(effective_layout, Layout, { format_ = LayoutType::kCudnnFlag; }); -#else - CHECK(param_.layout.value() == kNCW || - param_.layout.value() == kNCHW || - param_.layout.value() == kNCDHW) << "Need CuDNN > 5.0 for layout support"; -#endif // Double check to make sure this class supports the operation if (!Supports(param, forward_compute_type, backward_compute_type, rctx.ctx.dev_id)) - LOG(FATAL) << "Need CuDNN >= 6.0 for dilated deconvolution."; + LOG(FATAL) << "Deconvolution parameters not supported by cuDNN implementation."; InitDescriptors(in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); @@ -140,21 +136,6 @@ class CuDNNDeconvolutionOp { for (uint32_t g = 0; g < param_.num_group; ++g) { typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; - #if CUDNN_MAJOR <= 4 - CUDNN_CALL(cudnnConvolutionBackwardData_v3(s->dnn_handle_, - &alpha, - filter_desc_, - wmat_ptr + weight_offset_ * g, - in_desc_, - data_ptr + data_offset_ * g, - forward_conv_desc_, // this backward algorithm used for inference - back_algo_.AlgoNumber(), - workspace.dptr_, - workspace_size, - &beta, - out_desc_, - out.dptr_ + out_offset_ * g)); - #elif CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnConvolutionBackwardData(s->dnn_handle_, &alpha, filter_desc_, @@ -168,11 +149,9 @@ class CuDNNDeconvolutionOp { &beta, out_desc_, out_ptr + out_offset_ * g)); - #endif if (!param_.no_bias) { beta = 1.0f; Tensor bias = in_data[deconv::kBias].get(s); -#if CUDNN_MAJOR >= 4 CUDNN_CALL(cudnnAddTensor(s->dnn_handle_, &alpha, bias_desc_, @@ -180,17 +159,6 @@ class CuDNNDeconvolutionOp { &beta, out_desc_, out_ptr + out_offset_ * g)); -#endif -#if CUDNN_MAJOR == 3 - CUDNN_CALL(cudnnAddTensor(s->dnn_handle_, - CUDNN_ADD_SAME_C, - &alpha, - bias_desc_, - bias.dptr_ + bias_offset_ * g, - &beta, - out_desc_, - out_ptr + out_offset_ * g)); -#endif } } } @@ -244,23 +212,7 @@ class CuDNNDeconvolutionOp { gbias.dptr_ + bias_offset_ * g)); } if (req[deconv::kWeight] != kNullOp) { - #if CUDNN_MAJOR <= 4 - CUDNN_CALL(cudnnConvolutionBackwardFilter_v3( - s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr + out_offset_ * g, - in_desc_, - data_ptr + data_offset_ * g, - back_conv_desc_, - back_algo_w_.AlgoNumber(), - workspace.dptr_, - workspace_size, - &weight_beta, - filter_desc_, - gwmat.dptr_ + weight_offset_ * g)); - #elif CUDNN_MAJOR >= 5 - CHECK_EQ(add_to_weight_, req[deconv::kWeight] == kAddTo); + CHECK_EQ(add_to_weight_, req[deconv::kWeight] == kAddTo); CUDNN_CALL(cudnnConvolutionBackwardFilter( s->dnn_handle_, &alpha, @@ -275,7 +227,6 @@ class CuDNNDeconvolutionOp { &weight_beta, filter_desc_, gwmat_ptr + weight_offset_ * g)); - #endif } if (req[deconv::kData] != kNullOp) { CUDNN_CALL(cudnnConvolutionForward(s->dnn_handle_, @@ -323,16 +274,7 @@ class CuDNNDeconvolutionOp { // The factor by which the effective filter size grows based on dilation. auto filterDilationFactor = param.dilate.Size(); - // The v6 kernels that backprop a dilated convolution don't handle fp16. - // Since the deconvolution "forward" kernel is really a backprop-to-data - // cuDNN kernel, the following logic is slightly different than that - // used in CuDNNConvolution::Supports(). - - // Dilation support across all architectures only available after v6.0.20. - return filterDilationFactor == 1 || - filterDilationFactor > 1 && (CUDNN_VERSION > 6020) && - (backward_compute_type != kFloat16) && - (forward_compute_type != kFloat16); + return true; } private: @@ -362,13 +304,6 @@ class CuDNNDeconvolutionOp { mxnet::TShape oshape = out_shape[deconv::kOut]; mxnet::TShape dstride, ostride; wshape[0] /= param_.num_group; -#if CUDNN_MAJOR <= 5 - // As of cuDNN_v6, the unsuffixed version of cudnnSetConvolution2dDescriptor() - // takes an additional 'computeType' parameter to set the precision of the - // convolution calculation. Supply this method signature for cuDNN versions < 6. -#define cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m, ct) \ - cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m) -#endif if (param_.kernel.ndim() == 1 || param_.kernel.ndim() == 2) { // 1d or 2d conv index_t o_pad[2]; @@ -414,13 +349,6 @@ class CuDNNDeconvolutionOp { dilate[1], CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); -#if CUDNN_MAJOR < 5 - // As of cuDNN_v5, cudnnSetFilter4dDescriptor() takes a format parameter. - // Supply this method signature for cuDNN versions < 5. -#define cudnnSetFilter4dDescriptor(fdesc, dt, f, w0, w1, w2, w3) \ - cudnnSetFilter4dDescriptor(fdesc, dt, w0, w1, w2, w3) - CHECK_EQ(format_, CUDNN_TENSOR_NCHW) << "CuDNN V4 and earlier only supports NCHW layout"; -#endif if (param_.kernel.ndim() == 2) { wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); dstride = ConvertLayout(Strides<4>(dshape), param_.layout.value(), kNCHW); @@ -465,7 +393,6 @@ class CuDNNDeconvolutionOp { index_t o_adj[3]; param_.InferPad(dshape, o_pad, o_adj); - #if CUDNN_MAJOR >= 5 CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; std::vector wshape_buffer(wshape.ndim()); CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_, @@ -473,9 +400,6 @@ class CuDNNDeconvolutionOp { CUDNN_TENSOR_NCHW, static_cast(wshape.ndim()), CastTShapeToIntPtr(wshape, &wshape_buffer))); - #else - LOG(FATAL) << "Only support CUDNN V5 for 3D convolution"; - #endif CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, 3, reinterpret_cast(&o_pad[0]), @@ -506,13 +430,11 @@ class CuDNNDeconvolutionOp { oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW); } // Set "allow tensor core" flag in convolution descriptors, if available. -#if CUDNN_MAJOR >= 7 cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH - : CUDNN_DEFAULT_MATH; + : CUDNN_DEFAULT_MATH; CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type)); CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type)); CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type)); -#endif dshape[1] /= param_.num_group; oshape[1] /= param_.num_group; weight_offset_ = wshape.Size(); @@ -566,12 +488,9 @@ class CuDNNDeconvolutionOp { mshadow::Stream *s = rctx.get_stream(); CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); -#if CUDNN_MAJOR >= 7 - // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire - // story: the notion of whether the algo ran in Tensor Core mode is not known. - // Since we want to report the Tensor Core mode in the verbose output, we switch - // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches - // that of *Find*(), we can unify the find-vs-get logic by using function pointers. + + // Since the function signature of *Get*_v7() matches that of *Find*(), + // we can unify the find-vs-get logic by using function pointers. // Forward Algorithm Find/Get() v7 std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); @@ -632,134 +551,6 @@ class CuDNNDeconvolutionOp { AlgoFinalSelect(bwd_data_results, "backprop-to-data", workspace_byte, bwd, exclude_dgrad_algo_); -#else - // CUDNN_MAJOR < 7 - const int kMaxAlgos = 10; - int nalgo = kMaxAlgos; - int i = 0; - size_t min_memory_needs = 0; - // Forward Algorithm Find/Get, v6 and earlier - if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { - // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is - // supported. Hard-coded this since the algo find() or get() throws an FPE. - fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); - } else if (!param_.cudnn_tune.value()) { - cudnnConvolutionFwdAlgo_t fastest_fwd_algo; - CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used in dgrad - in_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_fwd_algo)); - fwd->Set(fastest_fwd_algo, false); - } else { - cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used in dgrad - in_desc_, - kMaxAlgos, - &nalgo, - fwd_algo)); - i = 0; - while (i < nalgo - && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && fwd_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - fwd_algo[i].memory : - std::min(min_memory_needs, fwd_algo[i].memory); - } - if (i == nalgo) { - LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, - "forward algos (for use in deconv op backprop-to-data)"); - } else { - fwd->Set(fwd_algo[i].algo, false); - } - } - // Backprop-to-Filter Algorithm Find/Get, v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_filt_algo)); - flt->Set(fastest_bwd_filt_algo, false); - } else { - cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo)); - i = 0; - while (i < nalgo - && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && bwd_filter_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_filter_algo[i].memory : - std::min(min_memory_needs, bwd_filter_algo[i].memory); - } - if (i == nalgo) { - LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, - "backward filter algos (for use in deconv op backprop-to-filter)"); - } else { - flt->Set(bwd_filter_algo[i].algo, false); - } - } - // Backprop-to-Data Algorithm Get(), v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used for inference - out_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_data_algo)); - bwd->Set(fastest_bwd_data_algo, false); - } else { - cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used in inference - out_desc_, - kMaxAlgos, - &nalgo, - bwd_data_algo)); - i = 0; - while (i < nalgo - && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && bwd_data_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_data_algo[i].memory : - std::min(min_memory_needs, bwd_data_algo[i].memory); - } - if (i == nalgo) { - LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, - "backward data algos (for use in deconv op forward inference)"); - } else { - bwd->Set(bwd_data_algo[i].algo, false); - } - } -#endif // CUDNN_MAJOR < 7 // Fix for issue #11241 int cudnn_find_issue_max_features = 64 * 1024; @@ -794,9 +585,9 @@ class CuDNNDeconvolutionOp { // DirectFree(), which makes these areas available for cudnn's subsequent cudaMalloc(). // Allocate for x (or dx), w (or dw) and y (or dy). - ReserveElements({in_shape[conv::kData].Size(), - in_shape[conv::kWeight].Size(), - out_shape[conv::kOut].Size()}); + ReserveElements({in_shape[deconv::kData].Size(), + in_shape[deconv::kWeight].Size(), + out_shape[deconv::kOut].Size()}); // We're about to call cudnnFind so we need to quiet the system by grabbing // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate timing @@ -825,7 +616,7 @@ class CuDNNDeconvolutionOp { // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, // we must change the descriptor to preclude Tensor Core. Simplest is to // once again set the mathType in all cases. - #if CUDNN_MAJOR >= 7 + // The next two code lines will look like they have typos, but they don't! // The forward_conv_desc_ is used during inference, which invokes the back_algo_. // Thus, the mathType of the back_algo_ should be stored in the forward_conv_desc_. @@ -835,7 +626,6 @@ class CuDNNDeconvolutionOp { CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, back_algo_.MathType())); CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, forward_algo_.MathType())); CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); - #endif } // Look over the results from *Find*() or *Get*() and pick the fastest algo given possible @@ -850,20 +640,16 @@ class CuDNNDeconvolutionOp { const auto &result = perf_results[i]; bool algo_exclusion = static_cast(result.algo) == algo_exclude; bool algo_is_tensor_core = false; - #if CUDNN_MAJOR >= 7 - algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; - #endif + algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; if (result.status == CUDNN_STATUS_SUCCESS && - #if CUDNN_MAJOR >= 7 (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && - #endif - (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte) && + (param_.cudnn_tune.value() != deconv::kLimited || result.memory <= workspace_byte) && !algo_exclusion) { algo->Set(result.algo, algo_is_tensor_core); return; } } - auto mode = param_.cudnn_tune.value() == conv::kOff ? " get " : " find "; + auto mode = param_.cudnn_tune.value() == deconv::kOff ? " get " : " find "; LOG(FATAL) << "Failed to" << mode << "any " << kernel_name << " deconvolution algorithm" << " with workspace size of " << workspace_byte << " bytes," << " please consider reducing batch/model size or increasing the workspace size"; diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index ada605db0ee9..f52848b4a452 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -35,6 +35,8 @@ namespace op { template class CuDNNPoolingOp { + STATIC_ASSERT_CUDNN_VERSION_GE(7000); + public: CuDNNPoolingOp() { // TODO(xxx): fp16 @@ -48,12 +50,8 @@ class CuDNNPoolingOp { param_ = p; switch (param_.pool_type) { case pool_enum::kMaxPooling: - #if CUDNN_MAJOR >= 7 mode_ = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false) ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; - #else - mode_ = CUDNN_POOLING_MAX; - #endif break; case pool_enum::kAvgPooling: if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) { @@ -229,10 +227,6 @@ class CuDNNPoolingOp { #endif } else if (param.kernel.ndim() == 3) { // 3d pooling -#if CUDNN_MAJOR < 5 - LogUnsupportedDim(&unsupported_dim_warning_issued, param.kernel.ndim()); - return false; -#endif if (!(layout == mshadow::kNCDHW || layout == mshadow::kNDHWC)) return false; } else { @@ -250,9 +244,7 @@ class CuDNNPoolingOp { const TBlob &out_data) { using namespace mshadow; bool is_supported = true; - #if CUDNN_MAJOR >= 5 nan_prop_ = CUDNN_NOT_PROPAGATE_NAN; - #endif int layout = param_.GetLayout(in_data.ndim()); if (param_.kernel.ndim() == 2) { // 2d pooling @@ -290,7 +282,6 @@ class CuDNNPoolingOp { #if CUDNN_VERSION == 7104 is_supported = kernel_height <= 8 && kernel_width <= 8; #endif - #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_, mode_, nan_prop_, @@ -300,16 +291,6 @@ class CuDNNPoolingOp { param_.global_pool ? 0 : param_.pad[1], param_.global_pool ? 1 : param_.stride[0], param_.global_pool ? 1 : param_.stride[1])); - #else - CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_, - mode_, - kernel_height, - kernel_width, - param_.global_pool ? 0 : param_.pad[0], - param_.global_pool ? 0 : param_.pad[1], - param_.global_pool ? 1 : param_.stride[0], - param_.global_pool ? 1 : param_.stride[1])); - #endif } else { CHECK(layout == mshadow::kNCDHW || layout == mshadow::kNDHWC) << "Need 3D layout NCDHW or NDHWC."; @@ -376,7 +357,6 @@ class CuDNNPoolingOp { static_cast(oshape_ncdhw_int.size()), &oshape_ncdhw_int[0], &ostride_ncdhw_int[0])); - #if CUDNN_MAJOR >= 5 CUDNN_CALL(cudnnSetPoolingNdDescriptor(pooling_desc_, mode_, nan_prop_, @@ -384,9 +364,6 @@ class CuDNNPoolingOp { &(kernel_vec[0]), &(pad_vec[0]), &(stride_vec[0]))); - #else - LOG(FATAL) << "3D pooling is only supported by CUDNN v5 and above."; - #endif } return is_supported; } @@ -406,9 +383,7 @@ class CuDNNPoolingOp { cudnnTensorDescriptor_t in_desc_; cudnnTensorDescriptor_t out_desc_; cudnnPoolingDescriptor_t pooling_desc_; - #if CUDNN_MAJOR >= 5 cudnnNanPropagation_t nan_prop_; - #endif PoolingParam param_; }; // class CuDNNPoolingOp } // namespace op diff --git a/src/operator/quantization/quantized_conv.cu b/src/operator/quantization/quantized_conv.cu index 23c41a17ef4a..28bd43239066 100644 --- a/src/operator/quantization/quantized_conv.cu +++ b/src/operator/quantization/quantized_conv.cu @@ -49,7 +49,8 @@ struct QuantizedBiasAddKernel { } }; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#if MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 +STATIC_ASSERT_CUDNN_VERSION_GE(6000); template class QuantizedCuDNNConvOp { public: @@ -260,7 +261,7 @@ class QuantizedCuDNNConvOp { float alpha_ = 1.0f; float beta_ = 0.0f; }; // class QuantizedCuDNNConvOp -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#endif // MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -270,7 +271,7 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param = nnvm::get(attrs.parsed); CHECK_EQ(param.kernel.ndim(), 2U) << "QuantizedConvForward only supports 2D convolution for now"; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#if MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 typedef QuantizedCuDNNConvOp QuantizedConvOpInt8; #if DMLC_CXX11_THREAD_LOCAL static thread_local QuantizedConvOpInt8 op; @@ -282,7 +283,7 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, #else LOG(FATAL) << "QuantizedConvForward only supports cudnnConvolutionForward " "with CUDNN >= 6.0 and CUDA >= 8.0"; -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#endif // MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 } NNVM_REGISTER_OP(_contrib_quantized_conv) diff --git a/src/operator/quantization/quantized_pooling.cu b/src/operator/quantization/quantized_pooling.cu index a8fba87090ab..167c08e99683 100644 --- a/src/operator/quantization/quantized_pooling.cu +++ b/src/operator/quantization/quantized_pooling.cu @@ -29,7 +29,8 @@ namespace mxnet { namespace op { -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#if MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 +STATIC_ASSERT_CUDNN_VERSION_GE(6000); template class QuantizedCuDNNPoolingOp { public: @@ -115,7 +116,7 @@ class QuantizedCuDNNPoolingOp { cudnnTensorDescriptor_t out_desc_; cudnnPoolingDescriptor_t pool_desc_; }; // class QuantizedCuDNNPoolingOp -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#endif // MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -125,7 +126,7 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, const PoolingParam& param = nnvm::get(attrs.parsed); CHECK_EQ(param.kernel.ndim(), 2U) << "QuantizedPoolingForward only supports 2D convolution for now"; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#if MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 #if DMLC_CXX11_THREAD_LOCAL static thread_local QuantizedCuDNNPoolingOp op; #else @@ -136,7 +137,7 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, #else LOG(FATAL) << "QuantizedPoolingForward only supports cudnnPoolingForward " "with CUDNN >= 6.0 and CUDA >= 8.0"; -#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000 +#endif // MXNET_USE_CUDNN == 1 && CUDA_VERSION >= 8000 } NNVM_REGISTER_OP(_contrib_quantized_pooling) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 328e28de8537..5eae413b078b 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -26,7 +26,9 @@ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ -#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 +STATIC_ASSERT_CUDNN_VERSION_GE(7000); +#endif #define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 #include @@ -396,7 +398,7 @@ class RNNOp { public: RNNParam param_; Context ctx_; - #if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 std::vector concat_weight_memory; std::vector concat_iter_memory; std::vector rnn_forward_prim; @@ -411,15 +413,15 @@ class RNNOp { bool init_mem_; size_t reserve_mem_size_; Storage::Handle mem_space_; - #endif +#endif explicit RNNOp(RNNParam param, Context ctx) { this->param_ = param; this->ctx_ = ctx; - #if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 init_mem_ = false; reserve_mem_size_ = 0; - #endif - #if MXNET_USE_CUDNN_RNN +#endif +#if MXNET_USE_CUDNN == 1 init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. @@ -457,7 +459,7 @@ class RNNOp { #else CHECK(!param_.projection_size.has_value()) << "Projection is only supported for LSTM with CuDNN version later than 7.1.1."; -#endif +#endif // MXNET_USE_CUDNN_GE_7200 #if MXNET_USE_CUDNN_GE_7200 if (param_.lstm_state_clip_min.has_value() || param_.lstm_state_clip_max.has_value()) { @@ -472,7 +474,7 @@ class RNNOp { CHECK(!param_.lstm_state_clip_min.has_value() && !param_.lstm_state_clip_max.has_value()) << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1."; -#endif +#endif // MXNET_USE_CUDNN_GE_7200 // RNN Direction direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; // Create descriptors @@ -491,17 +493,17 @@ class RNNOp { CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_)); - #endif - #else +#endif +#else if (ctx_.dev_type == kGPU) { LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; } - #endif +#endif // MXNET_USE_CUDNN == 1 if (ctx_.dev_type == kCPU) { this->init_space_ = false; @@ -520,13 +522,13 @@ class RNNOp { } ~RNNOp() { - #if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 if (init_mem_) { Storage::Get()->Free(mem_space_); init_mem_ = false; } - #endif - #if MXNET_USE_CUDNN_RNN +#endif // MXNET_USE_MKLDNN +#if MXNET_USE_CUDNN == 1 CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_)); @@ -551,13 +553,13 @@ class RNNOp { init_cudnn_ = false; Storage::Get()->Free(reserve_space_); } - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_)); CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_)); CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_)); CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_)); - #endif - #endif +#endif // MXNET_USE_CUDNN_GE_7200 +#endif // MXNET_USE_CUDNN if (ctx_.dev_type == kCPU) { if (init_space_) { @@ -671,7 +673,7 @@ class RNNOp { CHECK_EQ(hx.CheckContiguous(), true); CHECK_EQ(y.CheckContiguous(), true); -#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) +#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) if (!init_cudnn_) { Init(ctx, s, in_data, out_data); } @@ -736,9 +738,7 @@ class RNNOp { sequence_length_cpu_int, reinterpret_cast(&padding_fill_))); } -#endif -#if MXNET_USE_CUDNN_GE_7200 bool clip_state = param_.lstm_state_clip_min.has_value(); bool clip_nan = param_.lstm_state_clip_nan; CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_, @@ -747,7 +747,7 @@ class RNNOp { clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN, clip_state ? param_.lstm_state_clip_min.value() : 0.0, clip_state ? param_.lstm_state_clip_max.value() : 0.0)); -#endif +#endif // MXNET_USE_CUDNN_GE_7200 if (ctx.is_train) { #if MXNET_USE_CUDNN_GE_7200 @@ -801,7 +801,7 @@ class RNNOp { workspace_byte_, reserve_space_.dptr, reserve_space_byte_)); -#endif +#endif // MXNET_USE_CUDNN_GE_7200 } else { #if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_, @@ -850,9 +850,9 @@ class RNNOp { cy_ptr, temp_space.dptr_, workspace_byte_)); -#endif +#endif // MXNET_USE_CUDNN_GE_7200 } -#endif +#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) if (ctx_.dev_type == kCPU) { if (ctx.is_train) { @@ -907,7 +907,7 @@ class RNNOp { param_.p, param_.mode); } else { - #if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one // will be added to MXNet when we figure out the issue. @@ -942,7 +942,7 @@ class RNNOp { ctx.is_train, param_.mode); } else { - #endif +#endif // MXNET_USE_MKLDNN == 1 // Before integrating MKLDNN GRU fp32 inference // using below code for keep func being OK const size_t work_cpu_space_size = @@ -976,9 +976,9 @@ class RNNOp { hy_ptr, cy_ptr, param_.mode); - #if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 } - #endif +#endif } } } @@ -1061,7 +1061,7 @@ class RNNOp { dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get(s)).dptr_; } - #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) +#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) if (!init_cudnn_) { Init(ctx, s, in_data, out_data); } @@ -1072,7 +1072,7 @@ class RNNOp { ctx.requested[rnn_enum::kTempSpace].get_space_typed( mshadow::Shape1(temp_size), s); - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_, rnn_desc_, y_data_desc_, @@ -1117,7 +1117,7 @@ class RNNOp { dw.dptr_, reserve_space_.dptr, reserve_space_byte_)); - #else +#else CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_, rnn_desc_, param_.seq_length_, @@ -1160,8 +1160,8 @@ class RNNOp { dw.dptr_, reserve_space_.dptr, reserve_space_byte_)); - #endif - #endif +#endif // MXNET_USE_CUDNN_GE_7200 +#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) if (ctx_.dev_type == kCPU) { // allocate temp space @@ -1230,10 +1230,8 @@ class RNNOp { CHECK_EQ(in_data.size(), num_inputs); CHECK_EQ(out_data.size(), num_outputs); - #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__) - #if CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) format_ = CUDNN_TENSOR_NCHW; - #endif if (!init_cudnn_) { init_cudnn_ = true; @@ -1304,7 +1302,7 @@ class RNNOp { strideA[0] = dimA[2] * dimA[1]; strideA[1] = dimA[2]; strideA[2] = 1; - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 int dimB[3]; int strideB[3]; dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1); @@ -1314,74 +1312,74 @@ class RNNOp { strideB[0] = dimB[2] * dimB[1]; strideB[1] = dimB[2]; strideB[2] = 1; - #endif - #if MXNET_USE_CUDNN_GE_7200 +#endif // MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dimB, strideB)); - #else +#else CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dimA, strideA)); - #endif +#endif // MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_, dtype_, 3, dimA, strideA)); - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dimB, strideB)); - #else +#else CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dimA, strideA)); - #endif +#endif // MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dimA, strideA)); - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dimB, strideB)); - #else +#else CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dimA, strideA)); - #endif +#endif // MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dimA, strideA)); - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dimB, strideB)); - #else +#else CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dimA, strideA)); - #endif +#endif // MXNET_USE_CUDNN_GE_7200 CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, @@ -1403,20 +1401,19 @@ class RNNOp { // RNN descriptors cudnnDataType_t dtype_with_fallback_; - #if CUDNN_MAJOR >= 6 cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD; // On arch's 50 and 52(Maxwell), the gpu doesn't support native fp16 compute. // Before cuDNN 7.5.0, when running fp16, cuDNN fallback to fp32 under the hood on Maxwell. // That's not the case begining from 7.5.0. Thereby adding fallback explicitly here. - #if __CUDA_ARCH__ < 530 && CUDNN_MAJOR >=7 && CUDNN_MINOR >= 5 +#if __CUDA_ARCH__ < 530 && CUDNN_VERSION >= 7500 if (dtype_ == CUDNN_DATA_HALF) { dtype_with_fallback_ = CUDNN_DATA_FLOAT; } else { dtype_with_fallback_ = dtype_; } - #else +#else dtype_with_fallback_ = dtype_; - #endif +#endif CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_, rnn_desc_, param_.state_size, @@ -1427,45 +1424,30 @@ class RNNOp { mode_, rnn_algo, dtype_with_fallback_)); - #else - CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_, - param_.state_size, - param_.num_layers, - dropout_desc_, - input_mode_, - direction_, - mode_, - dtype_)); - #endif - #if CUDNN_MAJOR >= 7 - cudnnMathType_t math_type = CUDNN_DEFAULT_MATH; - if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) { - math_type = CUDNN_TENSOR_OP_MATH; - } - #if CUDNN_VERSION >= 7200 - if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && - (DataType::kFlag != kFloat16)) { - math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; - } - #endif - CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type)); - #endif - #if MXNET_USE_CUDNN_GE_7200 + cudnnMathType_t math_type = CUDNN_DEFAULT_MATH; + if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) { + math_type = CUDNN_TENSOR_OP_MATH; + } +#if CUDNN_VERSION >= 7200 + if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && + (DataType::kFlag != kFloat16)) { + math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; + } +#endif + CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type)); +#if MXNET_USE_CUDNN_GE_7200 if (param_.projection_size.has_value()) { CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_, rnn_desc_, param_.projection_size.value(), 0)); } - #endif - // Get temp space sizes - - #if MXNET_USE_CUDNN_GE_7200 if (param_.use_sequence_length) { CUDNN_CALL(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED)); } - #endif +#endif // MXNET_USE_CUDNN_GE_7200 + // Get temp space sizes CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_, rnn_desc_, param_.seq_length_, @@ -1537,9 +1519,9 @@ class RNNOp { // } // } } - #endif +#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) } - #if MXNET_USE_CUDNN_RNN +#if MXNET_USE_CUDNN == 1 cudnnDataType_t dtype_; bool init_cudnn_; cudnnRNNDescriptor_t rnn_desc_; @@ -1552,10 +1534,10 @@ class RNNOp { size_t workspace_byte_, reserve_space_byte_; int workspace_size_; std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_; - #if MXNET_USE_CUDNN_GE_7200 +#if MXNET_USE_CUDNN_GE_7200 cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; DType padding_fill_ = 0; - #endif +#endif // MXNET_USE_CUDNN_GE_7200 cudnnTensorDescriptor_t hx_desc_, cx_desc_; cudnnTensorDescriptor_t hy_desc_, cy_desc_; cudnnTensorDescriptor_t dhx_desc_, dcx_desc_; @@ -1565,10 +1547,8 @@ class RNNOp { // Allow TensorCore algo policy bool cudnn_tensor_core_; - #if CUDNN_MAJOR >= 5 cudnnTensorFormat_t format_; - #endif - #endif +#endif // MXNET_USE_CUDNN bool init_space_, temp_init_space_; size_t reserve_cpu_space_size_, temp_cpu_space_size_; Storage::Handle reserve_cpu_space_, temp_cpu_space_; diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 244e39335a91..86fb1c7d1ec6 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -171,7 +171,7 @@ static std::vector RNNResourceEx(const NodeAttrs& attrs, const const DispatchMode dispatch_mode) { std::vector request; if (dev_mask == kGPU) { -#if MXNET_USE_CUDNN_RNN +#if MXNET_USE_CUDNN == 1 request.emplace_back(ResourceRequest::kTempSpace); const RNNParam& param = nnvm::get(attrs.parsed); diff --git a/src/operator/spatial_transformer-inl.h b/src/operator/spatial_transformer-inl.h index 660d57d55bab..1a684a899d85 100644 --- a/src/operator/spatial_transformer-inl.h +++ b/src/operator/spatial_transformer-inl.h @@ -267,7 +267,7 @@ class SpatialTransformerProp : public OperatorProperty { return {ResourceRequest::kTempSpace}; } - #if CUDNN_MAJOR >= 5 + #if MXNET_USE_CUDNN == 1 std::vector BackwardResource( const mxnet::ShapeVector &in_shape) const override { return {ResourceRequest::kTempSpace}; diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu index fd330bd4ca87..4067714d426d 100644 --- a/src/operator/spatial_transformer.cu +++ b/src/operator/spatial_transformer.cu @@ -26,9 +26,9 @@ #include "./spatial_transformer-inl.h" #include -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 #include "./cudnn_spatial_transformer-inl.h" -#endif // MXNET_USE_CUDNN && CUDNN_MAJOR +#endif // MXNET_USE_CUDNN namespace mshadow { template @@ -214,7 +214,7 @@ namespace op { template<> Operator* CreateOp(SpatialTransformerParam param, int dtype) { Operator *op = NULL; -#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#if MXNET_USE_CUDNN == 1 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { if (param.cudnn_off.has_value() && param.cudnn_off.value()) { op = new SpatialTransformerOp(param); @@ -226,7 +226,7 @@ Operator* CreateOp(SpatialTransformerParam param, int dtype) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new SpatialTransformerOp(param); }) -#endif // MXNET_USE_CUDNN && CUDNN_MAJOR +#endif // MXNET_USE_CUDNN return op; }