From e10026768d56e82b7fd4d5799af1cbb7313ef154 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 12 Mar 2018 22:57:04 -0700 Subject: [PATCH] Fix registering quantized nn ops --- src/operator/quantization/quantized_conv.cu | 8 ++++---- src/operator/quantization/quantized_fully_connected.cu | 2 -- src/operator/quantization/quantized_pooling.cu | 8 ++++---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/operator/quantization/quantized_conv.cu b/src/operator/quantization/quantized_conv.cu index 39d7e078802e..b45124ad3c6b 100644 --- a/src/operator/quantization/quantized_conv.cu +++ b/src/operator/quantization/quantized_conv.cu @@ -23,7 +23,6 @@ * \brief * \author Ziheng Jiang, Jun Wu */ -#if MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 #include "../nn/convolution-inl.h" #include "./quantization_utils.h" #include "../tensor/matrix_op-inl.h" @@ -50,6 +49,7 @@ struct QuantizedBiasAddKernel { } }; +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 template class QuantizedCuDNNConvOp { public: @@ -260,6 +260,7 @@ class QuantizedCuDNNConvOp { float alpha_ = 1.0f; float beta_ = 0.0f; }; // class QuantizedCuDNNConvOp +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -269,7 +270,7 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param = nnvm::get(attrs.parsed); CHECK_EQ(param.kernel.ndim(), 2U) << "QuantizedConvForward only supports 2D convolution for now"; -#if MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 typedef QuantizedCuDNNConvOp QuantizedConvOpInt8; #if DMLC_CXX11_THREAD_LOCAL static thread_local QuantizedConvOpInt8 op; @@ -280,7 +281,7 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs, op.Forward(ctx, inputs, req, outputs); #else LOG(FATAL) << "QuantizedConvForward only supports cudnnConvolutionForward for now"; -#endif // MXNET_USE_CUDNN +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 } NNVM_REGISTER_OP(_contrib_quantized_conv) @@ -288,4 +289,3 @@ NNVM_REGISTER_OP(_contrib_quantized_conv) } // namespace op } // namespace mxnet -#endif diff --git a/src/operator/quantization/quantized_fully_connected.cu b/src/operator/quantization/quantized_fully_connected.cu index e66993f23506..ac7ba1e21df8 100644 --- a/src/operator/quantization/quantized_fully_connected.cu +++ b/src/operator/quantization/quantized_fully_connected.cu @@ -23,7 +23,6 @@ * \brief * \author Ziheng Jiang, Jun Wu */ -#if MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 #include "./quantization_utils.h" #include "../mxnet_op.h" #include "../nn/fully_connected-inl.h" @@ -121,4 +120,3 @@ NNVM_REGISTER_OP(_contrib_quantized_fully_connected) } // namespace op } // namespace mxnet -#endif diff --git a/src/operator/quantization/quantized_pooling.cu b/src/operator/quantization/quantized_pooling.cu index c687e71ae889..1bb08f470de7 100644 --- a/src/operator/quantization/quantized_pooling.cu +++ b/src/operator/quantization/quantized_pooling.cu @@ -21,7 +21,6 @@ * Copyright (c) 2017 by Contributors * \file quantized_pooling.cu */ -#if MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 #include #include #include "../nn/pooling-inl.h" @@ -30,6 +29,7 @@ namespace mxnet { namespace op { +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 template class QuantizedCuDNNPoolingOp { public: @@ -115,6 +115,7 @@ class QuantizedCuDNNPoolingOp { cudnnTensorDescriptor_t out_desc_; cudnnPoolingDescriptor_t pool_desc_; }; // class QuantizedCuDNNPoolingOp +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -124,7 +125,7 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, const PoolingParam& param = nnvm::get(attrs.parsed); CHECK_EQ(param.kernel.ndim(), 2U) << "QuantizedPoolingForward only supports 2D convolution for now"; -#if MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 #if DMLC_CXX11_THREAD_LOCAL static thread_local QuantizedCuDNNPoolingOp op; #else @@ -134,7 +135,7 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs, op.Forward(ctx.get_stream(), inputs, req, outputs); #else LOG(FATAL) << "QuantizedPoolingForward only supports cudnnPoolingForward for now"; -#endif // MXNET_USE_CUDNN +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 } NNVM_REGISTER_OP(_contrib_quantized_pooling) @@ -142,4 +143,3 @@ NNVM_REGISTER_OP(_contrib_quantized_pooling) } // namespace op } // namespace mxnet -#endif