This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

cuDNN support cleanup #15812

Merged
3 commits merged on Aug 13, 2019
16 changes: 4 additions & 12 deletions src/common/cuda_utils.h
@@ -29,6 +29,7 @@
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <mshadow/base.h>
#include <mxnet/libinfo.h>

/*! \brief Macros/inlines to assist CLion to parse Cuda files (*.cu, *.cuh) */
#ifdef __JETBRAINS_IDE__
@@ -482,13 +483,10 @@ static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10,
* want to populate.
*/
inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
STATIC_ASSERT_CUDNN_VERSION_GE(7000);
int max_algos = 0;
CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
return max_algos;
#else
return 10;
#endif
}

/*!
@@ -499,13 +497,10 @@ inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
* want to populate.
*/
inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
STATIC_ASSERT_CUDNN_VERSION_GE(7000);
int max_algos = 0;
CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
return max_algos;
#else
return 10;
#endif
}

/*!
@@ -516,13 +511,10 @@ inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
* want to populate.
*/
inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
STATIC_ASSERT_CUDNN_VERSION_GE(7000);
int max_algos = 0;
CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
return max_algos;
#else
return 10;
#endif
}

#endif // MXNET_USE_CUDNN
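Note: the STATIC_ASSERT_CUDNN_VERSION_GE macro used above is defined elsewhere in cuda_utils.h and is not part of this diff. A minimal sketch of how such a check could look, assuming cudnn.h's convention that CUDNN_VERSION equals CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL (so 7000 corresponds to cuDNN 7.0.0):

// Illustrative sketch only -- the actual macro lives elsewhere in cuda_utils.h.
#define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \
  static_assert(CUDNN_VERSION >= (min_version), \
                "Compiled-in cuDNN version is older than the required minimum " #min_version)
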
4 changes: 2 additions & 2 deletions src/executor/attach_op_resource_pass.cc
@@ -82,12 +82,12 @@ void AttachOpResources(
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#if MXNET_USE_CUDNN == 1
case ResourceRequest::kCuDNNDropoutDesc: {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
break;
}
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#endif // MXNET_USE_CUDNN == 1
default:
LOG(FATAL) << "resource type " << req.type << " is not yet supported";
}
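Note: the kCuDNNDropoutDesc requests dispatched above are declared by individual operators through the FResourceRequest attribute. A hedged sketch of such a registration (loosely modeled on the Dropout operator; the exact registration is not part of this diff):

// Illustrative sketch only -- operator name and wiring are simplified here.
NNVM_REGISTER_OP(Dropout)
.set_attr<FResourceRequest>("FResourceRequest", [](const nnvm::NodeAttrs& attrs) {
  std::vector<ResourceRequest> request;
#if MXNET_USE_CUDNN == 1
  // Ask the resource manager for a shared cuDNN dropout descriptor.
  request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
#endif
  return request;
});
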
4 changes: 2 additions & 2 deletions src/imperative/imperative_utils.h
@@ -263,12 +263,12 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
requested.push_back(ResourceManager::Get()->Request(ctx, req));
write_vars.push_back(requested.back().var);
break;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#if MXNET_USE_CUDNN == 1
case ResourceRequest::kCuDNNDropoutDesc:
requested.push_back(ResourceManager::Get()->Request(ctx, req));
write_vars.push_back(requested.back().var);
break;
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
#endif // MXNET_USE_CUDNN == 1
default:
LOG(FATAL) << "resource type not yet supported";
}
8 changes: 4 additions & 4 deletions src/operator/bilinear_sampler.cu
@@ -27,9 +27,9 @@
#include "./bilinear_sampler-inl.h"
#include <algorithm>
#include "../common/cuda_utils.h"
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if MXNET_USE_CUDNN == 1
#include "./cudnn_bilinear_sampler-inl.h"
#endif // MXNET_USE_CUDNN && CUDNN_MAJOR
#endif // MXNET_USE_CUDNN

namespace mshadow {
namespace cuda {
@@ -228,7 +228,7 @@ namespace op {
template<>
Operator* CreateOp<gpu>(BilinearSamplerParam param, int dtype) {
Operator *op = NULL;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if MXNET_USE_CUDNN == 1
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
if (param.cudnn_off.has_value() && param.cudnn_off.value()) {
op = new BilinearSamplerOp<gpu, DType>(param);
@@ -240,7 +240,7 @@ Operator* CreateOp<gpu>(BilinearSamplerParam param, int dtype) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new BilinearSamplerOp<gpu, DType>(param);
})
#endif // MXNET_USE_CUDNN && CUDNN_MAJOR
#endif // MXNET_USE_CUDNN
return op;
}

9 changes: 3 additions & 6 deletions src/operator/cudnn_bilinear_sampler-inl.h
@@ -31,7 +31,8 @@
#include "./bilinear_sampler-inl.h"
namespace mxnet {
namespace op {
#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1
STATIC_ASSERT_CUDNN_VERSION_GE(5000);
template<typename DType>
class CuDNNBilinearSamplerOp : public Operator {
public:
@@ -132,9 +133,7 @@ class CuDNNBilinearSamplerOp : public Operator {
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data) {
using namespace mshadow;
#if CUDNN_MAJOR >= 5
format_ = CUDNN_TENSOR_NCHW;
#endif
CHECK_EQ(in_data.size(), 2U);
CHECK_EQ(out_data.size(), 2U);
if (!init_cudnn_) {
@@ -174,12 +173,10 @@ class CuDNNBilinearSamplerOp : public Operator {
cudnnTensorDescriptor_t in_desc_;
cudnnTensorDescriptor_t out_desc_;
cudnnSamplerType_t sampler_;
#if CUDNN_MAJOR >= 5
cudnnTensorFormat_t format_;
#endif
BilinearSamplerParam param_;
};
#endif // __CUDACC__ && CUDNN
#endif // __CUDACC__ && MXNET_USE_CUDNN
} // namespace op
} // namespace mxnet

9 changes: 3 additions & 6 deletions src/operator/cudnn_spatial_transformer-inl.h
@@ -31,7 +31,8 @@
#include "./spatial_transformer-inl.h"
namespace mxnet {
namespace op {
#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1
STATIC_ASSERT_CUDNN_VERSION_GE(5000);
template<typename DType>
class CuDNNSpatialTransformerOp : public Operator {
public:
@@ -145,9 +146,7 @@ class CuDNNSpatialTransformerOp : public Operator {
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data) {
using namespace mshadow;
#if CUDNN_MAJOR >= 5
format_ = CUDNN_TENSOR_NCHW;
#endif
CHECK_EQ(in_data.size(), 2U);
CHECK_EQ(out_data.size(), 3U);
if (!init_cudnn_) {
@@ -189,12 +188,10 @@ class CuDNNSpatialTransformerOp : public Operator {
cudnnTensorDescriptor_t in_desc_;
cudnnTensorDescriptor_t out_desc_;
cudnnSamplerType_t sampler_;
#if CUDNN_MAJOR >= 5
cudnnTensorFormat_t format_;
#endif
SpatialTransformerParam param_;
};
#endif // __CUDACC__ && CUDNN
#endif // __CUDACC__ && MXNET_USE_CUDNN
} // namespace op
} // namespace mxnet

8 changes: 4 additions & 4 deletions src/operator/nn/batch_norm.cu
@@ -35,7 +35,7 @@
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm-inl.h"
#endif

@@ -641,7 +641,7 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu> *stream,
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
}

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4
#if MXNET_USE_CUDNN == 1
template<typename DType>
static CuDNNBatchNormOp<DType> &GetCuDNNOp(const BatchNormParam& param) {
#if DMLC_CXX11_THREAD_LOCAL
@@ -667,7 +667,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
mxnet::TShape shape = inputs[0].shape_;

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -696,7 +696,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
mxnet::TShape shape = inputs[0].shape_;

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
99 changes: 35 additions & 64 deletions src/operator/nn/convolution.cu
@@ -94,39 +94,8 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
int dtype = inputs[conv::kData].type_flag_;

#if CUDNN_MAJOR < 5
if (param.layout.value() != kNCW &&
param.layout.value() != kNCHW &&
param.layout.value() != kNCDHW) {
// Need CuDNN > 5.0 for layout support. use MXNet implementation
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
ConvolutionOp<gpu, DType> op;
op.Init(param);
op.Forward(ctx, inputs, req, outputs);
})
return;
}
#endif

#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
param.num_filter == inputs[conv::kData].shape_[1] &&
param.kernel.ndim() == 2 &&
param.dilate == mshadow::Shape2(1, 1) &&
dtype == mshadow::kFloat32) {
mxnet::ShapeVector in_shape(inputs.size());
mxnet::ShapeVector out_shape(1, outputs[0].shape_);
for (size_t i = 0; i < in_shape.size(); i++)
in_shape[i] = inputs[i].shape_;
DepthwiseConvolutionOp<float> op;
op.Init(param, in_shape, out_shape);
op.Forward(ctx, inputs, req, outputs);
return;
}
#endif

#if MXNET_USE_CUDNN == 1
STATIC_ASSERT_CUDNN_VERSION_GE(7000);
// On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16).
int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype;

@@ -154,6 +123,22 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
}
})
#else
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
param.num_filter == inputs[conv::kData].shape_[1] &&
param.kernel.ndim() == 2 &&
param.dilate == mshadow::Shape2(1, 1) &&
dtype == mshadow::kFloat32) {
mxnet::ShapeVector in_shape(inputs.size());
mxnet::ShapeVector out_shape(1, outputs[0].shape_);
for (size_t i = 0; i < in_shape.size(); i++)
in_shape[i] = inputs[i].shape_;
DepthwiseConvolutionOp<float> op;
op.Init(param, in_shape, out_shape);
op.Forward(ctx, inputs, req, outputs);
return;
}

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
ConvolutionOp<gpu, DType> op;
op.Init(param);
@@ -174,39 +159,8 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &in_grad = outputs;
int dtype = out_grad.type_flag_;

#if CUDNN_MAJOR < 5
if (param.layout.value() != kNCW &&
param.layout.value() != kNCHW &&
param.layout.value() != kNCDHW) {
// Need CuDNN > 5.0 for layout support. use MXNet implementation
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
ConvolutionOp<gpu, DType> op;
op.Init(param);
op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
})
return;
}
#endif
#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
param.num_filter == in_data[conv::kData].shape_[1] &&
param.kernel.ndim() == 2 &&
param.dilate == mshadow::Shape2(1, 1) &&
dtype == mshadow::kFloat32) {
// The first element stores out grad.
mxnet::ShapeVector in_shape(in_data.size());
mxnet::ShapeVector out_shape(1, out_grad.shape_);
for (size_t i = 0; i < in_shape.size(); i++)
in_shape[i] = in_data[i].shape_;
DepthwiseConvolutionOp<float> op;
op.Init(param, in_shape, out_shape);
op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
return;
}
#endif

#if MXNET_USE_CUDNN == 1
STATIC_ASSERT_CUDNN_VERSION_GE(7000);
// On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16).
int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype;

@@ -234,6 +188,23 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
}
})
#else
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
param.num_filter == in_data[conv::kData].shape_[1] &&
param.kernel.ndim() == 2 &&
param.dilate == mshadow::Shape2(1, 1) &&
dtype == mshadow::kFloat32) {
// The first element stores out grad.
mxnet::ShapeVector in_shape(in_data.size());
mxnet::ShapeVector out_shape(1, out_grad.shape_);
for (size_t i = 0; i < in_shape.size(); i++)
in_shape[i] = in_data[i].shape_;
DepthwiseConvolutionOp<float> op;
op.Init(param, in_shape, out_shape);
op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
return;
}

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
ConvolutionOp<gpu, DType> op;
op.Init(param);