Skip to content

Commit

Permalink
Fix registering quantized nn ops
Browse files — browse the repository at this point in the history
  • Loading branch information
reminisce committed Mar 13, 2018
1 parent 0a48f25 commit 667e261
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/operator/quantization/quantized_conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
* \brief
* \author Ziheng Jiang, Jun Wu
*/
#if MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
#include "../nn/convolution-inl.h"
#include "./quantization_utils.h"
#include "../tensor/matrix_op-inl.h"
Expand All @@ -50,6 +49,7 @@ struct QuantizedBiasAddKernel {
}
};

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
template<typename SrcType, typename DstType, typename CmpType>
class QuantizedCuDNNConvOp {
public:
Expand Down Expand Up @@ -260,6 +260,7 @@ class QuantizedCuDNNConvOp {
float alpha_ = 1.0f;
float beta_ = 0.0f;
}; // class QuantizedCuDNNConvOp
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6

void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand All @@ -269,7 +270,7 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
CHECK_EQ(param.kernel.ndim(), 2U)
<< "QuantizedConvForward<gpu> only supports 2D convolution for now";
#if MXNET_USE_CUDNN == 1
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
typedef QuantizedCuDNNConvOp<int8_t, float, int32_t> QuantizedConvOpInt8;
#if DMLC_CXX11_THREAD_LOCAL
static thread_local QuantizedConvOpInt8 op;
Expand All @@ -280,12 +281,11 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs,
op.Forward(ctx, inputs, req, outputs);
#else
LOG(FATAL) << "QuantizedConvForward<gpu> only supports cudnnConvolutionForward for now";
#endif // MXNET_USE_CUDNN
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
}

NNVM_REGISTER_OP(_contrib_quantized_conv)
.set_attr<FCompute>("FCompute<gpu>", QuantizedConvForwardGPU);

} // namespace op
} // namespace mxnet
#endif
2 changes: 0 additions & 2 deletions src/operator/quantization/quantized_fully_connected.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
* \brief
* \author Ziheng Jiang, Jun Wu
*/
#if MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
#include "./quantization_utils.h"
#include "../mxnet_op.h"
#include "../nn/fully_connected-inl.h"
Expand Down Expand Up @@ -121,4 +120,3 @@ NNVM_REGISTER_OP(_contrib_quantized_fully_connected)

} // namespace op
} // namespace mxnet
#endif
8 changes: 4 additions & 4 deletions src/operator/quantization/quantized_pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
* Copyright (c) 2017 by Contributors
* \file quantized_pooling.cu
*/
#if MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
#include <mxnet/operator_util.h>
#include <vector>
#include "../nn/pooling-inl.h"
Expand All @@ -30,6 +29,7 @@
namespace mxnet {
namespace op {

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
template<typename DType>
class QuantizedCuDNNPoolingOp {
public:
Expand Down Expand Up @@ -115,6 +115,7 @@ class QuantizedCuDNNPoolingOp {
cudnnTensorDescriptor_t out_desc_;
cudnnPoolingDescriptor_t pool_desc_;
}; // class QuantizedCuDNNPoolingOp
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6

void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand All @@ -124,7 +125,7 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs,
const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
CHECK_EQ(param.kernel.ndim(), 2U)
<< "QuantizedPoolingForward<gpu> only supports 2D convolution for now";
#if MXNET_USE_CUDNN == 1
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
#if DMLC_CXX11_THREAD_LOCAL
static thread_local QuantizedCuDNNPoolingOp<int8_t> op;
#else
Expand All @@ -134,12 +135,11 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs,
op.Forward(ctx.get_stream<gpu>(), inputs, req, outputs);
#else
LOG(FATAL) << "QuantizedPoolingForward<gpu> only supports cudnnPoolingForward for now";
#endif // MXNET_USE_CUDNN
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
}

NNVM_REGISTER_OP(_contrib_quantized_pooling)
.set_attr<FCompute>("FCompute<gpu>", QuantizedPoolingForwardGPU);

} // namespace op
} // namespace mxnet
#endif

0 comments on commit 667e261

Please sign in to comment.