diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py
index 5cb74ba11a89..231cc99f93bc 100644
--- a/example/ssd/quantization.py
+++ b/example/ssd/quantization.py
@@ -139,7 +139,6 @@ def save_params(fname, arg_params, aux_params, logger=None):
     mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
     if calib_mode == 'none':
-        logger.info('Quantizing FP32 model %s' % args.model)
         qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params,
                                                        aux_params=aux_params, ctx=ctx,
                                                        excluded_sym_names=excluded_sym_names,
                                                        calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
diff --git a/src/operator/nn/mkldnn/mkldnn_concat-inl.h b/src/operator/nn/mkldnn/mkldnn_concat-inl.h
new file mode 100644
index 000000000000..d3866cc3d23d
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_concat-inl.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_concat-inl.h
+ * \brief
+ * \author Wenting Jiang
+*/
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
+
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include <utility>
+#include "../concat-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNConcatFwd {
+ public:
+  mkldnn::concat::primitive_desc fwd_pd;
+
+  MKLDNNConcatFwd(int concat_dim, const std::vector<mkldnn::memory::primitive_desc> &data_md)
+      : fwd_pd(concat_dim, data_md) {
+    data.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);
+
+  const mkldnn::concat &GetFwd() const;
+
+ private:
+  std::shared_ptr<mkldnn::concat> fwd;
+  std::vector<std::shared_ptr<mkldnn::memory>> data;
+  std::vector<mkldnn::primitive::at> data_mem;
+  std::shared_ptr<mkldnn::memory> out;
+};
+
+static MKLDNNConcatFwd &GetConcatForward(
+    int concat_dim, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(concat_dim);
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNConcatFwd fwd(concat_dim, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index 03eeb61eccbf..8e2b57781a18 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -22,76 +22,36 @@
  * \brief
  * \author Wenting Jiang
 */
-#include "../concat-inl.h"
-#include "./mkldnn_ops-inl.h"
-#include "./mkldnn_base-inl.h"
 
 #if MXNET_USE_MKLDNN == 1
+#include "mkldnn_concat-inl.h"
+
"mkldnn_concat-inl.h" + namespace mxnet { namespace op { -class MKLDNNConcatFwd { - std::shared_ptr fwd; - std::vector> data; - std::vector data_mem; - std::shared_ptr out; - - public: - mkldnn::concat::primitive_desc fwd_pd; - - MKLDNNConcatFwd( - int concat_dim, - const std::vector &data_md): fwd_pd(concat_dim, data_md) { - data.resize(data_md.size()); - } - - void SetNewMem(const std::vector &in_data, - const mkldnn::memory &output) { - CHECK_EQ(in_data.size(), data.size()); - for (size_t i = 0; i < data.size(); i++) { - if (this->data[i] == nullptr) { - this->data[i] = std::shared_ptr(new mkldnn::memory( - in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); - this->data_mem.push_back(*this->data[i]); - } else { - this->data[i]->set_data_handle(in_data[i]->get_data_handle()); - } +void MKLDNNConcatFwd::SetNewMem(const std::vector &in_data, + const mkldnn::memory &output) { + CHECK_EQ(in_data.size(), data.size()); + for (size_t i = 0; i < data.size(); i++) { + if (this->data[i] == nullptr) { + this->data[i] = std::shared_ptr( + new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); + this->data_mem.push_back(*this->data[i]); + } else { + this->data[i]->set_data_handle(in_data[i]->get_data_handle()); } - if (this->out == nullptr) - this->out = std::shared_ptr(new mkldnn::memory( - fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out->set_data_handle(output.get_data_handle()); - - if (this->fwd == nullptr) - fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out)); } + if (this->out == nullptr) + this->out = std::shared_ptr( + new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); - const mkldnn::concat &GetFwd() const { - return *fwd; - } -}; - -static MKLDNNConcatFwd &GetConcatForward( - int concat_dim, const std::vector &in_data, - const std::vector &data_md) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fwds; -#else - static MX_THREAD_LOCAL std::unordered_map fwds; -#endif - OpSignature key; - key.AddSign(concat_dim); - key.AddSign(in_data); - - auto it = fwds.find(key); - if (it == fwds.end()) { - MKLDNNConcatFwd fwd(concat_dim, data_md); - it = AddToCache(&fwds, key, fwd); - } - return it->second; + if (this->fwd == nullptr) fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out)); } +const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd; } + void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, const std::vector &req, diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc new file mode 100644 index 000000000000..d9e884e82806 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file mkldnn_quantized_concat.cc
+ * \brief
+ */
+
+#if MXNET_USE_MKLDNN == 1
+#include "../../nn/mkldnn/mkldnn_concat-inl.h"
+#include "../quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+namespace quantized_concat_enum {
+enum QuantizedConcatOutputs { kOut, kMin, kMax };
+}
+
+static float GetScale(const NDArray& data, float min, float max) {
+  auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
+  return data_range / MaxAbs(min, max);
+}
+
+static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                                         const std::vector<NDArray>& in_data,
+                                         const std::vector<OpReqType>& req,
+                                         const std::vector<NDArray>& out_data) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_data.size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_data.size(), 3U);
+  // Collect data min/max and output_neg_min, output_pos_max
+  std::vector<float> data_min(param_.num_args);
+  std::vector<float> data_max(param_.num_args);
+  float output_neg_min = 0.f;  // 0.f is the maximum for output_neg_min
+  float output_pos_max = 0.f;  // 0.f is the minimum for output_pos_max
+  for (int i = 0; i < param_.num_args; ++i) {
+    data_min[i] = in_data[param_.num_args + 2 * i].data().dptr<float>()[0];
+    if (data_min[i] < output_neg_min) output_neg_min = data_min[i];
+    data_max[i] = in_data[param_.num_args + 2 * i + 1].data().dptr<float>()[0];
+    if (data_max[i] > output_pos_max) output_pos_max = data_max[i];
+  }
+  out_data[quantized_concat_enum::kMin].data().dptr<float>()[0] = output_neg_min;
+  out_data[quantized_concat_enum::kMax].data().dptr<float>()[0] = output_pos_max;
+  auto out_scale = GetScale(out_data[quantized_concat_enum::kOut], output_neg_min, output_pos_max);
+  std::vector<mkldnn::memory::primitive_desc> data_md;
+  std::vector<const mkldnn::memory*> data_mem;
+  // new_data_mem is for auto-free new created mkldnn memory
+  std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
+  for (int i = 0; i < param_.num_args; ++i) {
+    auto i_scale = GetScale(in_data[i], data_min[i], data_max[i]);
+    if (i_scale == out_scale) {
+      auto mem = in_data[i].GetMKLDNNData();
+      data_mem.push_back(mem);
+      data_md.push_back(mem->get_primitive_desc());
+    } else {
+      auto mem = in_data[i].GetMKLDNNData();
+      auto pd = mem->get_primitive_desc();
+      const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
+      new_data_mem.push_back(rescaled_mem);
+      std::vector<float> reorder_scale = {out_scale / i_scale};
+      primitive_attr reorder_attr;
+      reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
+      reorder_attr.set_output_scales(0, reorder_scale);
+      const auto reorder_pd = mkldnn::reorder::primitive_desc(pd, pd, reorder_attr);
+      MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem));
+      data_mem.push_back(rescaled_mem.get());
+      data_md.push_back(pd);
+    }
+  }
+  MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md);
+  mxnet::mkldnn_output_t out_mem =
+      CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_primitive_desc(),
+                      req[concat_enum::kOut]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data[concat_enum::kOut], out_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+inline static bool ConcatStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
+                                     DispatchMode* dispatch_mode, std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_attrs->size(), 3U);
+
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_concat)
+.set_attr<FInferStorageType>("FInferStorageType", ConcatStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedConcatForward)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<bool>("TIsMKLDNN", true);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h
index 5b096ac0057a..ee7112205892 100644
--- a/src/operator/quantization/quantization_utils.h
+++ b/src/operator/quantization/quantization_utils.h
@@ -31,6 +31,8 @@
 namespace mxnet {
 namespace op {
 
+static const size_t kUint8Range = 255;
+static const size_t kInt8Range = 127;
 
 template<typename T>
 MSHADOW_XINLINE int Sign(T val) {
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
new file mode 100644
index 000000000000..3504df82d243
--- /dev/null
+++ b/src/operator/quantization/quantized_concat.cc
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file quantized_concat.cc
+ * \brief
+*/
+
+#include "../nn/concat-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static bool ConcatShape(const nnvm::NodeAttrs& attrs, std::vector<TShape>* in_shape,
+                        std::vector<TShape>* out_shape) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_shape->size(), 3U);
+  TShape dshape;
+  index_t size = 0;
+  bool has_zero = false;
+  int axis = -1;
+  for (int i = 0; i < param_.num_args; ++i) {
+    TShape tmp = (*in_shape)[i];
+    if (tmp.ndim()) {
+      axis = CheckAxis(param_.dim, tmp.ndim());
+      has_zero = tmp[axis] == 0 || has_zero;
+      size += tmp[axis];
+      tmp[axis] = 0;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  TShape tmp = (*out_shape)[0];
+  if (tmp.ndim()) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = 0;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == 0) return false;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!has_zero) dshape[axis] = size;
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+
+  for (int i = param_.num_args; i < param_.num_args * 3; ++i) {
+    SHAPE_ASSIGN_CHECK(*in_shape, i, TShape{1});
+  }
+  SHAPE_ASSIGN_CHECK(*out_shape, 1, TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_shape, 2, TShape{1});
+  return dshape.Size() != 0;
+}
+
+static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
+                       std::vector<int>* out_type) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_type->size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_type->size(), 3U);
+  int dtype = mshadow::kUint8;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    if (in_type->at(i) == mshadow::kInt8) {
+      dtype = mshadow::kInt8;
+    } else {
+      TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kUint8);
+    }
+  }
+  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+  TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
+
+  return true;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_concat)
+.describe(R"code(Joins input arrays along a given axis.
+
+The dimensions of the input arrays should be the same except the axis along
+which they will be concatenated.
+The dimension of the output array along the concatenated axis will be equal
+to the sum of the corresponding dimensions of the input arrays.
+All inputs with different min/max ranges will be rescaled using the largest [min, max] pair.
+If any input holds int8, then the output will be int8. Otherwise, the output will be uint8.
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args * 3;
+})
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  std::vector<std::string> ret;
+  for (int i = 0; i < params.num_args; ++i) {
+    ret.push_back(std::string("arg") + std::to_string(i));
+  }
+  for (int i = 0; i < params.num_args; ++i) {
+    ret.push_back(std::string("arg") + std::to_string(i) + "_min");
+    ret.push_back(std::string("arg") + std::to_string(i) + "_max");
+  }
+  return ret;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output", "min_output", "max_output"};
+})
+.set_attr<nnvm::FInferType>("FInferType", ConcatType)
+.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
+NNVM_REGISTER_OP(Concat)
+.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+  nnvm::NodePtr node = nnvm::Node::Create();
+  node->attrs.op = Op::Get("_contrib_quantized_concat");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  if (node->op()->attr_parser != nullptr) {
+    node->op()->attr_parser(&(node->attrs));
+  }
+  return node;
+});
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
index 8675446f5a14..b44f2fb0e31e 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
@@ -36,9 +36,6 @@ struct MKLDNNConvFusionParam {
   std::shared_ptr<BatchNormParam> bn_param;
 };
 
-static const size_t uint8_range = 255;
-static const size_t int8_range = 127;
-
 enum MKLDNNConvOpOutputs { kOut, kMin, kMax };
 
 }  // namespace op
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index a1083d09b7b5..dfa98d1f5ee9 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -109,7 +109,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias,
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
     for (int c = 0; c < static_cast<int>(channel); ++c) {
       DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]);
-      weight_scales->at(c) = int8_range / weight_range;
+      weight_scales->at(c) = kInt8Range / weight_range;
       const DType *fp_ptr = weight_ptr + c * offset;
       int8_t *quan_ptr = quan_weight_ptr + c * offset;
       for (size_t k = 0; k < offset; ++k) {
@@ -125,7 +125,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias,
     }
     weight_scales->resize(1);
     DType weight_range = MaxAbs(total_min, total_max);
-    weight_scales->at(0) = int8_range / weight_range;
+    weight_scales->at(0) = kInt8Range / weight_range;
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
     for (int c = 0; c < static_cast<int>(channel); ++c) {
       const DType *fp_ptr = weight_ptr + c * offset;
@@ -327,7 +327,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     // Quantize weight and bias.
     if (mkldnn_param.quantized) {
       CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      auto data_range = (data.dtype() == mshadow::kInt8) ? int8_range : uint8_range;
+      auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
       float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_);
       MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
         QuantizeConvWeightBias<DType>(&cached_weight_, &cached_bias_,
@@ -346,12 +346,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
       LOG(FATAL) << "Can't handle negetive value for QuantizeData";
     }
     if (mkldnn_param.with_sum) {
-      auto quantized_sum_range = cached_sum_min_ < 0 ? int8_range : uint8_range;
+      auto quantized_sum_range = cached_sum_min_ < 0 ? kInt8Range : kUint8Range;
       sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_);
     }
     if (post_requantize) {
       quantized_out_range =
-          IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range;
+          IsOutputUInt8(mkldnn_param) ? kUint8Range : kInt8Range;
       out_range = MaxAbs(*out_min_ptr, *out_max_ptr);
       output_scale = quantized_out_range / out_range;
       full_conv_param.requantize_scales.resize(channel);
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 71784dcd3bf1..be6feaeb94a6 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -66,7 +66,7 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
     output.wait_to_read()
   return mod.get_outputs()
 
-def check_quantize(sym, data_shape):
+def check_quantize(sym, data_shape, check_conv=True):
   fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
   sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
   sym_sg = sym.get_backend_symbol("MKLDNN")
@@ -106,7 +106,8 @@ def check_quantize(sym, data_shape):
                                                                  calib_quantize_op=True,
                                                                  num_calib_examples=5)
   qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
-  check_qsym_calibrated(qsym)
+  if check_conv:
+    check_qsym_calibrated(qsym)
   quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
   for i in range(len(ref_out)):
     assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
@@ -229,6 +230,15 @@ def conv_bn_sum_relu(no_bias, data_shape):
   relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu")
   return relu, conv_bn_add_relu_attr
 
+# single concat case
+def single_concat(data_shape, input_num, dim):
+  data, weight = head_symbol(data_shape)
+  inputs = []
+  for i in range(input_num):
+    inputs.append(data)
+  concat = mx.symbol.Concat(*inputs, name="concat", dim=dim)
+  return concat
+
 def tail_neg_symbol(sym1, sym2):
   fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1')
   fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, flatten=True, name='fc2')
@@ -463,6 +473,15 @@ def test_pos_conv_bn_sum_relu():
     net, attrs = conv_bn_sum_relu(True, data_shape)
     check_fusion(net, data_shape, attrs)
 
+def test_pos_single_concat():
+  for data_shape in DATA_SHAPE:
+    net = single_concat(data_shape, 2, 1)
+    check_quantize(net, data_shape, False)
+    net = single_concat(data_shape, 4, 2)
+    check_quantize(net, data_shape, False)
+    net = single_concat(data_shape, 4, 3)
+    check_quantize(net, data_shape, False)
+
 @with_seed()
 def test_neg_conv_bn():
   for data_shape in DATA_SHAPE:
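
A note for readers of the patch above: the NumPy snippet below is a minimal sketch, not part of the diff, of the rescaling rule that MKLDNNQuantizedConcatForward applies. The function name `quantized_concat_reference` and its parameters are hypothetical. The output min/max is the widest [min, max] over all inputs (starting from 0), and any input whose quantization scale differs from the output scale is multiplied by out_scale / i_scale before concatenation.

```python
import numpy as np

def quantized_concat_reference(arrays, mins, maxs, axis=1, signed=True):
    """Illustrative sketch of the quantized concat rescaling rule.

    arrays: list of int8 (or uint8) numpy arrays sharing one dtype;
    mins/maxs: their float32 calibration ranges, one pair per input.
    """
    q_range = 127.0 if signed else 255.0           # kInt8Range / kUint8Range
    out_min = min([0.0] + list(mins))              # output_neg_min starts at 0
    out_max = max([0.0] + list(maxs))              # output_pos_max starts at 0
    out_scale = q_range / max(abs(out_min), abs(out_max))
    rescaled = []
    for arr, lo, hi in zip(arrays, mins, maxs):
        in_scale = q_range / max(abs(lo), abs(hi))
        if in_scale == out_scale:
            rescaled.append(arr)                   # reuse the input as-is
        else:
            # Mirrors the MKL-DNN reorder with output scale out_scale / in_scale
            # and round-to-nearest rounding; out_scale <= in_scale, so no overflow.
            rescaled.append(np.rint(arr.astype(np.float32) * (out_scale / in_scale))
                              .astype(arr.dtype))
    return np.concatenate(rescaled, axis=axis), out_min, out_max
```

The real operator performs this multiply with an MKL-DNN reorder primitive (set_output_scales plus round_nearest) rather than in NumPy, but the arithmetic is the same.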