diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h new file mode 100644 index 000000000000..6bf30e3f3bbe --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file mkldnn_act-inl.h + * \brief MKLDNN(Quantized) Activation operator based on subgraph + * \author Zhiyuan Huang +*/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_ + + +#if MXNET_USE_MKLDNN == 1 +#include <vector> +#include <utility> +#include "../activation-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param); +mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( + const ActivationParam& param, bool is_train, + const mkldnn::memory &input_mem, int dtype); + +class MKLDNNActForward { + public: + const mkldnn::eltwise_forward::primitive_desc fwd_pd; + + MKLDNNActForward(const ActivationParam& param, bool is_train, + const NDArray &data, const mkldnn::memory &mem): fwd_pd( + GetActFwdDescImpl(param, is_train, mem, data.dtype())) {} + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output); + const mkldnn::eltwise_forward &GetFwd() const; + + private: + std::shared_ptr<mkldnn::eltwise_forward> fwd_; + std::shared_ptr<mkldnn::memory> data_; + std::shared_ptr<mkldnn::memory> out_; +}; + +typedef ParamOpSign<ActivationParam> MKLDNNActSignature; +MKLDNNActForward &GetActForward(const ActivationParam& param, + const OpContext &ctx, const NDArray &in_data, + const mkldnn::memory &in_mem); + +void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index 8c64888b4608..9ce27fad4b19 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -32,8 +32,7 @@ #include <string> #include <utility> #include "../../operator_common.h" -#include "../activation-inl.h" -#include "./mkldnn_base-inl.h" 
+#include "mkldnn_act-inl.h" #if MXNET_USE_MKLDNN == 1 @@ -58,7 +57,7 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) { return SupportMKLDNNAct(param); } -static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) { +mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) { switch (param.act_type) { case activation::kReLU: return mkldnn::algorithm::eltwise_relu; @@ -74,9 +73,7 @@ static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) { } } -typedef std::shared_ptr mkldnn_act_pdesc_ptr; - -static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( +mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( const ActivationParam& param, bool is_train, const mkldnn::memory &input_mem, int dtype) { mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); @@ -84,65 +81,41 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( auto cpu_engine = data_mpd.get_engine(); auto alg = GetMKLDNNActAlgo(param); - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType alpha = 0; - mkldnn::eltwise_forward::desc desc = is_train - ? mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_training, - alg, data_md, alpha) - : mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_scoring, - alg, data_md, alpha); - return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); - }); - LOG(FATAL) << "Unsupported data type for MKLDNN activation"; - mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc( - mkldnn::prop_kind::forward_training, alg, data_md, 0.0); + + auto prop = is_train ? 
mkldnn::prop_kind::forward_training : + mkldnn::prop_kind::forward_scoring; + auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f); return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); } -typedef ParamOpSign MKLDNNActSignature; - -class MKLDNNActForward { - std::shared_ptr fwd; - std::shared_ptr data; - std::shared_ptr out; - - public: - const mkldnn::eltwise_forward::primitive_desc fwd_pd; - - MKLDNNActForward(const ActivationParam& param, bool is_train, - const NDArray &data, const mkldnn::memory &mem): fwd_pd( - GetActFwdDescImpl(param, is_train, mem, data.dtype())) { - } - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - data.get_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - - CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc()); - if (this->out == nullptr) - this->out = std::shared_ptr(new mkldnn::memory( - fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out->set_data_handle(output.get_data_handle()); - - if (this->fwd == nullptr) { - this->fwd = std::shared_ptr( - new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data), - *this->out)); - } +void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { + if (this->data_ == nullptr) + this->data_ = std::make_shared(data.get_primitive_desc(), + data.get_data_handle()); + else + this->data_->set_data_handle(data.get_data_handle()); + + CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc()); + if (this->out_ == nullptr) + this->out_ = std::make_shared(fwd_pd.dst_primitive_desc(), + output.get_data_handle()); + else + this->out_->set_data_handle(output.get_data_handle()); + + if (this->fwd_ == nullptr) { + this->fwd_ = std::shared_ptr( + new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_), + *this->out_)); 
} +} - const mkldnn::eltwise_forward &GetFwd() const { - return *fwd; - } -}; +const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const { + return *fwd_; +} -static MKLDNNActForward &GetActForward(const ActivationParam& param, - const OpContext &ctx, const NDArray &in_data, - const mkldnn::memory &in_mem) { +MKLDNNActForward &GetActForward(const ActivationParam& param, + const OpContext &ctx, const NDArray &in_data, + const mkldnn::memory &in_mem) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc new file mode 100644 index 000000000000..bc69cb5e9bf7 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * Copyright (c) 2019 by Contributors + * \file mkldnn_quantized_act.cc + * \brief MKLDNN(Quantized) Activation operator based on subgraph + * \author Zhiyuan Huang +*/ +#if MXNET_USE_MKLDNN == 1 + +#include "../../nn/mkldnn/mkldnn_act-inl.h" +#include "../quantization_utils.h" + +namespace mxnet { +namespace op { + +static void MKLDNNQuantizedActForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& in_data, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& out_data) { + CHECK(in_data[0].dtype() == mshadow::kUint8 || + in_data[0].dtype() == mshadow::kInt8) + << "_contrib_quantized_act op only supports uint8 and int8 as input " + "type"; + + MKLDNNActivationForward(attrs, ctx, in_data[0], req[0], out_data[0]); + out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0]; + out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0]; +} + +NNVM_REGISTER_OP(_contrib_quantized_act) +.set_attr<bool>("TIsMKLDNN", true) +.set_attr<FComputeEx>("FComputeEx", MKLDNNQuantizedActForward); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 5bd9e8af9038..7ff2999b0c15 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -89,11 +89,12 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs, return outputs; } -inline bool NeedQuantize(const NodePtr node, - const std::unordered_set<std::string>& excluded_nodes) { +inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set<std::string>& excluded_nodes) { + std::unordered_map<NodePtr, NodePtr> quantized_node; static auto& quantized_op_map = Op::GetAttr<FQuantizedOp>("FQuantizedOp"); static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType"); const auto& op = node->op(); + + if (op && quantized_op_map.count(op)) { bool need = true; if (excluded_nodes.count(node->attrs.name)) { @@ -112,14 +113,24 @@ inline bool NeedQuantize(const NodePtr node, }); } } - return need; + + if 
(need) { + auto n_ptr = quantized_op_map[node->op()]; + auto tmp_node = n_ptr(node->attrs); + if (tmp_node->op()) { + quantized_node[node] = tmp_node; + } else { + quantized_node[node] = nullptr; + } + } else { + quantized_node[node] = nullptr; + } } - return false; + return quantized_node[node]; } Graph QuantizeGraph(Graph &&src) { static const auto& flist_outputs = nnvm::Op::GetAttr("FListOutputNames"); - static const auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static const auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); @@ -136,11 +147,9 @@ Graph QuantizeGraph(Graph &&src) { NodePtr new_node = Node::Create(); // If the currently visited node needs quantization, insert a quantize op node before the // current node and replace the current node with the quantized version in the new graph. - if (NeedQuantize(node, excluded_nodes)) { - auto fquantized_op = quantized_op_map[node->op()]; - // If the currently visited node's op registered the FQuantizedOp property, new_node is a - // quantizated version of a that op, such as quantized_conv2d. - new_node = fquantized_op(node->attrs); + auto tmp_node = NeedQuantize(node, excluded_nodes); + if (tmp_node) { + new_node = tmp_node; // add data into quantized op input for (size_t i = 0; i < node->inputs.size(); ++i) { diff --git a/src/operator/quantization/quantized_activation.cc b/src/operator/quantization/quantized_activation.cc new file mode 100644 index 000000000000..4ab74d0b1c3f --- /dev/null +++ b/src/operator/quantization/quantized_activation.cc @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file quantized_activation.cc +*/ +#include +#include "../nn/activation-inl.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +bool QuantizedActivationShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + CHECK_EQ(in_shape->size(), 3U); + if (shape_is_none(in_shape->at(0))) return false; + SHAPE_ASSIGN_CHECK(*in_shape, 1, TShape{1}); + SHAPE_ASSIGN_CHECK(*in_shape, 2, TShape{1}); + out_shape->clear(); + out_shape->push_back((*in_shape)[0]); + out_shape->push_back(TShape{1}); + out_shape->push_back(TShape{1}); + return true; +} + +bool QuantizedActivationType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const ActivationParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_type->size(), 3U); + CHECK_EQ(out_type->size(), 3U); + if (param.act_type == activation::kReLU) { + TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt8); + } else { + LOG(FATAL) << "_contrib_quantized_act only supports act_type=relu for now"; + } + TYPE_ASSIGN_CHECK(*in_type, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_type, 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32); + return true; +} + +inline static bool QuantizedActivationStorageType(const nnvm::NodeAttrs &attrs, + const int 
dev_mask, + DispatchMode *dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 3); + + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + const ActivationParam &param = nnvm::get<ActivationParam>(attrs.parsed); + if (dev_mask == mshadow::cpu::kDevMask && param.act_type == activation::kReLU) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#else + CHECK_EQ(out_attrs->size(), 3); +#endif + for (int& out_attr : *out_attrs) + out_attr = kDefaultStorage; + return true; +} + +NNVM_REGISTER_OP(_contrib_quantized_act) +.describe(R"code(Activation operator for input and output data type of int8. +The input and output data comes with min and max thresholds for quantizing +the float32 data into int8. + +.. Note:: + This operator only supports forward propagation. DO NOT use it in training. + This operator only supports `relu`)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(3) +.set_attr_parser(ParamParser<ActivationParam>) +.set_attr<nnvm::FListInputNames>("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector<std::string>{"data", "min_data", "max_data"}; + }) +.set_attr<nnvm::FListOutputNames>("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector<std::string>{"output", "min_output", "max_output"}; + }) +.set_attr<nnvm::FInferType>("FInferType", QuantizedActivationType) +.set_attr<nnvm::FInferShape>("FInferShape", QuantizedActivationShape) +.set_attr<FInferStorageType>("FInferStorageType", QuantizedActivationStorageType) +.set_attr<FNeedRequantize>("FNeedRequantize", + [](const NodeAttrs& attrs) { + const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed); + CHECK(param.act_type == activation::kReLU) + << "_contrib_quantized_act only supports act_type=relu for now"; + return false; + }) +.add_argument("data", "NDArray-or-Symbol", "Input data.") +.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.") +.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.") +.add_arguments(ActivationParam::__FIELDS__()); + +NNVM_REGISTER_OP(Activation) +.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) { + ActivationParam param; 
+ param.Init(attrs.dict); + nnvm::NodePtr node = nnvm::Node::Create(); + if (param.act_type == activation::kReLU) { + node->attrs.op = Op::Get("_contrib_quantized_act"); + node->attrs.name = "quantized_" + attrs.name; + } else { + node->attrs.op = nullptr; + node->attrs.name = attrs.name; + } + node->attrs.dict = attrs.dict; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; +}); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 757df81e1607..2761e77fb0c1 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -414,6 +414,57 @@ def check_quantized_flatten(shape, qdtype): check_quantized_flatten((10, 15, 18), qdtype) check_quantized_flatten((3, 4, 23, 23), qdtype) +@with_seed() +def test_quantized_act(): + def check_quantized_act(data_shape, qdtype): + if is_test_for_native_cpu(): + print('skipped testing quantized_act for native cpu since it is not supported yet') + return + elif qdtype == 'int8' and is_test_for_mkldnn(): + print('skipped testing quantized_act for mkldnn cpu int8 since it is not supported yet') + return + elif is_test_for_gpu(): + print('skipped testing quantized_act for gpu since it is not supported yet') + return + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + act_fp32 = mx.sym.Activation(data=data, act_type='relu', name='relu') + arg_shapes, _, _ = act_fp32.infer_shape(data=data_shape) + arg_names = act_fp32.list_arguments() + act_fp32_exe = act_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 + else: + data_low = -127.0 + data_high = 127.0 + + act_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, + high=data_high, shape=data_shape).astype(qdtype) + output = act_fp32_exe.forward()[0] + + qdata = 
mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype) + min_data = mx.sym.Variable(name='min_data') + max_data = mx.sym.Variable(name='max_data') + quantized_act = mx.sym.contrib.quantized_act(data=qdata, min_data=min_data, max_data=max_data, act_type='relu') + act_int8_exe = quantized_act.simple_bind(ctx=mx.current_context(), grad_req='null') + qarg_names = quantized_act.list_arguments() + + act_int8_exe.arg_dict[qarg_names[0]][:] = act_fp32_exe.arg_dict[arg_names[0]].astype(qdtype) + quantized_range_min = mx.nd.min(act_int8_exe.arg_dict[qarg_names[0]][:]) + quantized_range_max = mx.nd.max(act_int8_exe.arg_dict[qarg_names[0]][:]) + act_int8_exe.arg_dict[qarg_names[1]][:] = quantized_range_min.astype(qdtype) + act_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range_max.astype(qdtype) + qoutput, min_range, max_range = act_int8_exe.forward() + + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + assert_almost_equal(min_range.asscalar(), quantized_range_min.asscalar()) + assert_almost_equal(max_range.asscalar(), quantized_range_max.asscalar()) + + for qdtype in ['int8', 'uint8']: + check_quantized_act((10,), qdtype) + check_quantized_act((10, 15), qdtype) + check_quantized_act((10, 15, 18), qdtype) + check_quantized_act((3, 4, 23, 23), qdtype) @with_seed() def test_quantize_params(): @@ -634,7 +685,9 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape): arg_params, aux_params = mod.get_params() excluded_names = [] if mx.current_context() == mx.cpu(): - excluded_names += ['fc'] + excluded_names += ['fc', 'conv1'] + if mx.current_context() == mx.gpu(): + excluded_names += ['relu0', 'relu1'] excluded_names += ['concat'] optional_names = ['pool0']