From fb1d3950d4e293853ea5e0dde5034275b1e092b0 Mon Sep 17 00:00:00 2001
From: bgawrych
Date: Sat, 30 Oct 2021 16:56:27 +0200
Subject: [PATCH] Add quantized batch_dot (#20680)

* Add quantized batch_dot

* Fix sanity

* Fix names for post quantize fuse

* Fixes

* update amp list

* fix sanity
---
 python/mxnet/amp/lists/symbol_bf16.py         |   1 +
 python/mxnet/amp/lists/symbol_fp16.py         |   1 +
 src/operator/nn/dnnl/dnnl_batch_dot-inl.h     |  73 +++++++-
 src/operator/nn/dnnl/dnnl_batch_dot.cc        | 157 ++++++++++++----
 src/operator/nn/dnnl/dnnl_ops-inl.h           |   1 +
 src/operator/subgraph/dnnl/dnnl_batch_dot.cc  | 176 ++++++++++++++++++
 .../subgraph/dnnl/dnnl_batch_dot_property.h   |  99 ++++++++++
 ...h => dnnl_matmul_post_quantize_property.h} |  34 ++--
 .../subgraph/dnnl/dnnl_subgraph_property.cc   |  10 +-
 .../dnnl/dnnl_transformer_qk_property.h       |   2 +-
 src/operator/tensor/dot-inl.h                 |  19 +-
 src/operator/tensor/dot.cc                    |   5 +-
 ...er_subgraph.py => test_matmul_subgraph.py} |  51 ++++-
 13 files changed, 551 insertions(+), 78 deletions(-)
 create mode 100644 src/operator/subgraph/dnnl/dnnl_batch_dot.cc
 create mode 100644 src/operator/subgraph/dnnl/dnnl_batch_dot_property.h
 rename src/operator/subgraph/dnnl/{dnnl_transformer_post_quantize_property.h => dnnl_matmul_post_quantize_property.h} (82%)
 rename tests/python/dnnl/subgraphs/{test_transformer_subgraph.py => test_matmul_subgraph.py} (67%)

diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index 86f5b0aabe72..dd545a778578 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -362,6 +362,7 @@
     'zeros_like',
     '_sg_onednn_conv',
     '_sg_onednn_fully_connected',
+    '_sg_onednn_batch_dot',
     'broadcast_mul',
     'Convolution_v1',
     'IdentityAttachKLSparseReg',
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 45ac56c4b346..e54b523e1fe6 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -620,6 +620,7 @@
     '_sg_onednn_fully_connected',
     '_sg_onednn_selfatt_qk',
     '_sg_onednn_selfatt_valatt',
+    '_sg_onednn_batch_dot'
     ])
 
 # Functions that have to be cast to FP32 only for
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index 2c07a32f2153..0d5d72828462 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -38,22 +38,67 @@
 namespace mxnet {
 namespace op {
 
+enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
+enum DotOut { out = 0, out_min, out_max };
+
+struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
+  bool transpose_a;
+  bool transpose_b;
+  bool quantized;
+
+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+  bool enable_float_output;               // whether to produce float32 output instead of int8/int32
+  DMLC_DECLARE_PARAMETER(DNNLDotParam) {
+    DMLC_DECLARE_FIELD(transpose_a)
+        .describe("If true then transpose the first input before dot.")
+        .set_default(false);
+    DMLC_DECLARE_FIELD(transpose_b)
+        .describe("If true then transpose the second input before dot.")
+        .set_default(false);
+    DMLC_DECLARE_FIELD(quantized).set_default(false).describe("Enable quantization.");
+    DMLC_DECLARE_FIELD(min_calib_range)
+        .set_default(dmlc::optional<float>())
+        .describe(
+            "The minimum scalar value in the form of float32 obtained "
+            "through calibration. If present, it will be used by the "
+            "quantized batch_dot op to calculate the primitive scale.");
+    DMLC_DECLARE_FIELD(max_calib_range)
+        .set_default(dmlc::optional<float>())
+        .describe(
+            "The maximum scalar value in the form of float32 obtained "
+            "through calibration. If present, it will be used by the "
+            "quantized batch_dot op to calculate the primitive scale.");
+    DMLC_DECLARE_FIELD(enable_float_output)
+        .set_default(false)
+        .describe("Whether to enable float32 output.");
+  }
+
+  bool operator==(const DNNLDotParam& other) const {
+    return this->transpose_a == other.transpose_a && this->transpose_b == other.transpose_b &&
+           this->quantized == other.quantized && this->min_calib_range == other.min_calib_range &&
+           this->max_calib_range == other.max_calib_range;
+  }
+};
+
 using batch_dot_fwd_t    = dnnl::matmul;
 using batch_dot_fwd_pd_t = dnnl::matmul::primitive_desc;
 
-typedef ParamOpSign<DotParam> BatchDotSignature;
+typedef ParamOpSign<DNNLDotParam> BatchDotSignature;
 
 class DNNLBatchDotFwd {
  public:
-  static DNNLBatchDotFwd& GetCached(const DotParam& param,
+  static DNNLBatchDotFwd& GetCached(const DNNLDotParam& param,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<NDArray>& outputs);
 
-  DNNLBatchDotFwd(const DotParam& param,
+  DNNLBatchDotFwd(const DNNLDotParam& param,
                   const std::vector<NDArray>& inputs,
                   const std::vector<NDArray>& outputs);
 
-  void Execute(const std::vector<NDArray>& inputs,
+  void Execute(const OpContext& ctx,
+               const DNNLDotParam& param,
+               const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<NDArray>& outputs);
 
@@ -62,6 +107,26 @@ class DNNLBatchDotFwd {
   std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;
 };
 
+template <bool subgraph = true>
+void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  DNNLDotParam dnnl_param;
+  if (!subgraph) {
+    const DotParam& param  = nnvm::get<DotParam>(attrs.parsed);
+    dnnl_param.transpose_a = param.transpose_a;
+    dnnl_param.transpose_b = param.transpose_b;
+    dnnl_param.quantized   = false;
+  } else {
+    dnnl_param = nnvm::get<DNNLDotParam>(attrs.parsed);
+  }
+
+  DNNLBatchDotFwd& fwd = DNNLBatchDotFwd::GetCached(dnnl_param, inputs, outputs);
+  fwd.Execute(ctx, dnnl_param, inputs, req, outputs);
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot.cc b/src/operator/nn/dnnl/dnnl_batch_dot.cc
index bb9f911ee8ec..26a1acef3763 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -25,27 +25,21 @@
 #if MXNET_USE_ONEDNN == 1
 
 #include "./dnnl_batch_dot-inl.h"
+#include "../../quantization/quantization_utils.h"
 
 namespace mxnet {
 namespace op {
 
+DMLC_REGISTER_PARAMETER(DNNLDotParam);
+
 bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output) {
-  return inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
+  return inputs[DotIn::lhs].shape().Size() != 0 && inputs[DotIn::rhs].shape().Size() != 0 &&
          output.shape().Size() != 0 &&
-         (inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kBfloat16);
-}
-
-void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<NDArray>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<NDArray>& outputs) {
-  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  DNNLBatchDotFwd& fwd  = DNNLBatchDotFwd::GetCached(param, inputs, outputs);
-  fwd.Execute(inputs, req, outputs);
+         (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
+          inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
 }
 
-DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
+DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DNNLDotParam& param,
                                             const std::vector<NDArray>& inputs,
                                             const std::vector<NDArray>& outputs) {
   using batch_dot_fwd_map = std::unordered_map<BatchDotSignature, DNNLBatchDotFwd, OpHash>;
@@ -56,9 +50,9 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
 #endif
 
   BatchDotSignature key(param);
-  key.AddSign(inputs[0]);
-  key.AddSign(inputs[1]);
-  key.AddSign(outputs[0]);
+  key.AddSign(inputs[DotIn::lhs]);
+  key.AddSign(inputs[DotIn::rhs]);
+  key.AddSign(outputs[DotOut::out]);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
@@ -68,14 +62,40 @@ DNNLBatchDotFwd& DNNLBatchDotFwd::GetCached(const DotParam& param,
   return it->second;
 }
 
-DNNLBatchDotFwd::DNNLBatchDotFwd(const DotParam& param,
+dnnl::primitive_attr GetQuantizationAttributes(const DNNLDotParam& param,
+                                               const std::vector<NDArray>& inputs,
+                                               const std::vector<NDArray>& outputs) {
+  dnnl::primitive_attr attr;
+  float out_scale = 1.f;
+  float lhs_scale = GetQuantizeScale(inputs[DotIn::lhs].dtype(),
+                                     inputs[DotIn::lhs_min].data().dptr<float>()[0],
+                                     inputs[DotIn::lhs_max].data().dptr<float>()[0]);
+  float rhs_scale = GetQuantizeScale(inputs[DotIn::rhs].dtype(),
+                                     inputs[DotIn::rhs_min].data().dptr<float>()[0],
+                                     inputs[DotIn::rhs_max].data().dptr<float>()[0]);
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    // fused requantize => output is int
+    out_scale = GetQuantizeScale(outputs[DotOut::out].dtype(),
+                                 param.min_calib_range.value(),
+                                 param.max_calib_range.value()) /
+                lhs_scale / rhs_scale;
+    attr.set_output_scales(0, {out_scale});
+  } else if (param.enable_float_output) {
+    out_scale = 1.0 / lhs_scale / rhs_scale;
+    attr.set_output_scales(0, {out_scale});
+  }
+
+  return attr;
+}
+
+DNNLBatchDotFwd::DNNLBatchDotFwd(const DNNLDotParam& param,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<NDArray>& outputs) {
-  auto shape  = inputs[0].shape();
-  auto ndim   = shape.ndim();
-  auto bigDim = shape[0];
+  auto lhs_shape = inputs[DotIn::lhs].shape();
+  auto ndim      = lhs_shape.ndim();
+  auto bigDim    = lhs_shape[0];
   for (size_t i = 1; i < ndim - 2; ++i) {
-    bigDim *= shape[i];
+    bigDim *= lhs_shape[i];
   }
 
   auto GetMemoryDesc = [&ndim, &bigDim](const NDArray& tensor, const bool transpose) {
@@ -91,37 +111,106 @@ DNNLBatchDotFwd::DNNLBatchDotFwd(const DotParam& param,
     }
   };
 
-  dnnl::memory::desc data_md    = GetMemoryDesc(inputs[0], param.transpose_a);
-  dnnl::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
+  dnnl::memory::desc data_md    = GetMemoryDesc(inputs[DotIn::lhs], param.transpose_a);
+  dnnl::memory::desc weights_md = GetMemoryDesc(inputs[DotIn::rhs], param.transpose_b);
   dnnl::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
-                            get_dnnl_type(outputs[0].dtype()),
+                            get_dnnl_type(outputs[DotOut::out].dtype()),
                             dnnl::memory::format_tag::any);
   dnnl::matmul::desc fwd_desc(data_md, weights_md, out_md);
-  fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
-  fwd    = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
+  if (param.quantized) {
+    auto attrs = GetQuantizationAttributes(param, inputs, outputs);
+    fwd_pd     = std::make_shared<batch_dot_fwd_pd_t>(
+        fwd_desc, attrs, mxnet::CpuEngine::Get()->get_engine());
+
+  } else {
+    fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  }
+
+  fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
 }
 
-void DNNLBatchDotFwd::Execute(const std::vector<NDArray>& inputs,
+void DNNLBatchDotFwd::Execute(const OpContext& ctx,
+                              const DNNLDotParam& param,
+                              const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
   auto engine = mxnet::CpuEngine::Get()->get_engine();
-  auto data =
-      dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[0].data().dptr_));
-  auto weights =
-      dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(inputs[1].data().dptr_));
-  dnnl_output_t out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
+  auto lhs    = inputs[DotIn::lhs];
+  auto rhs    = inputs[DotIn::rhs];
+  // Created primitive descriptor assumes that both inputs are in default format
+  if (lhs.IsDNNLData())
+    lhs = lhs.Reorder2Default();
+  if (rhs.IsDNNLData())
+    rhs = rhs.Reorder2Default();
+
+  auto lhs_mem =
+      dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(lhs.data().dptr_));
+  auto rhs_mem =
+      dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(rhs.data().dptr_));
+  dnnl_output_t out_mem = CreateDNNLMem(
+      outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);
 
   dnnl_args_map_t args = {
-      {DNNL_ARG_SRC, data},
-      {DNNL_ARG_WEIGHTS, weights},
+      {DNNL_ARG_SRC, lhs_mem},
+      {DNNL_ARG_WEIGHTS, rhs_mem},
       {DNNL_ARG_DST, *out_mem.second},
   };
 
   DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
   CommitOutput(outputs[0], out_mem);
   DNNLStream::Get()->Submit();
+
+  if (param.quantized && !param.enable_float_output) {
+    mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+    float min_output;
+    float max_output;
+    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+      min_output = param.min_calib_range.value();
+      max_output = param.max_calib_range.value();
+    } else {
+      if (inputs[DotIn::lhs].dtype() == mshadow::kInt8) {
+        mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
+            s,
+            1,
+            &min_output,
+            &max_output,
+            inputs[DotIn::rhs_min].data().dptr<float>(),
+            inputs[DotIn::rhs_max].data().dptr<float>(),
+            inputs[DotIn::lhs_min].data().dptr<float>(),
+            inputs[DotIn::lhs_max].data().dptr<float>());
+      } else {
+        mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
+            s,
+            1,
+            &min_output,
+            &max_output,
+            inputs[DotIn::rhs_min].data().dptr<float>(),
+            inputs[DotIn::rhs_max].data().dptr<float>(),
+            inputs[DotIn::lhs_min].data().dptr<float>(),
+            inputs[DotIn::lhs_max].data().dptr<float>());
+      }
+    }
+
+    float* min_output_ptr = outputs[DotOut::out_min].data().dptr<float>();
+    float* max_output_ptr = outputs[DotOut::out_max].data().dptr<float>();
+    *min_output_ptr       = min_output;
+    *max_output_ptr       = max_output;
+  }
 }
 
 }  // namespace op
 }  // namespace mxnet
+
+namespace std {
+template <>
+struct hash<mxnet::op::DNNLDotParam> {
+  size_t operator()(const mxnet::op::DNNLDotParam& val) {
+    size_t ret = 0;
+    ret        = dmlc::HashCombine(ret, val.transpose_a);
+    ret        = dmlc::HashCombine(ret, val.transpose_b);
+    ret        = dmlc::HashCombine(ret, val.quantized);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h b/src/operator/nn/dnnl/dnnl_ops-inl.h
index 8816c3c1f659..d1b6f5928141 100644
--- a/src/operator/nn/dnnl/dnnl_ops-inl.h
+++ b/src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -159,6 +159,7 @@ void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
                         const std::vector<NDArray>& outputs);
 
 /* For batch dot */
+template <bool subgraph>
 void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<NDArray>& inputs,
diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
new file mode 100644
index 000000000000..612629da6ccf
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_batch_dot.cc
+ * \brief DNNL (Quantized) batch_dot operator based on subgraph
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../../nn/dnnl/dnnl_base-inl.h"
+#include "../../nn/dnnl/dnnl_batch_dot-inl.h"
+#include "../../nn/dnnl/dnnl_ops-inl.h"
+#include "../../quantization/quantization_utils.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "../common.h"
+
+namespace mxnet {
+namespace op {
+
+bool DNNLBatchDotShape(const nnvm::NodeAttrs& attrs,
+                       mxnet::ShapeVector* in_shapes,
+                       mxnet::ShapeVector* out_shapes) {
+  const DNNLDotParam& param = nnvm::get<DNNLDotParam>(attrs.parsed);
+  mxnet::ShapeVector base_in_shapes;
+  mxnet::ShapeVector base_out_shapes;
+  const size_t base_num_inputs = 2;
+
+  base_out_shapes.push_back(out_shapes->at(DotOut::out));
+  for (size_t i = 0; i < base_num_inputs; ++i) {
+    base_in_shapes.push_back(in_shapes->at(i));
+  }
+  BatchDotShape<DNNLDotParam>(attrs, &base_in_shapes, &base_out_shapes);
+
+  for (size_t i = 0; i < in_shapes->size(); ++i) {
+    if (i < base_in_shapes.size()) {
+      in_shapes->at(i) = base_in_shapes[i];
+    } else {
+      SHAPE_ASSIGN_CHECK(*in_shapes, i, mshadow::Shape1(1));
+    }
+  }
+
+  out_shapes->at(DotOut::out) = base_out_shapes[DotOut::out];
+  if (param.quantized && !param.enable_float_output) {
+    SHAPE_ASSIGN_CHECK(*out_shapes, DotOut::out_min, mshadow::Shape1(1));
+    SHAPE_ASSIGN_CHECK(*out_shapes, DotOut::out_max, mshadow::Shape1(1));
+  }
+
+  return true;
+}
+
+bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
+                      std::vector<int>* in_types,
+                      std::vector<int>* out_types) {
+  const DNNLDotParam& param    = nnvm::get<DNNLDotParam>(attrs.parsed);
+  const size_t base_num_inputs = 2;
+  if (param.quantized) {
+    CHECK(in_types->at(DotIn::lhs) == mshadow::kInt8 || in_types->at(DotIn::lhs) == mshadow::kUint8)
+        << "Quantized batch-dot lhs only supports int8/uint8 input, while "
+        << in_types->at(DotIn::lhs) << " is given.";
+    CHECK(in_types->at(DotIn::rhs) == mshadow::kInt8 || in_types->at(DotIn::rhs) == mshadow::kUint8)
+        << "Quantized batch-dot rhs only supports int8/uint8 input, while "
+        << in_types->at(DotIn::rhs) << " is given.";
+
+    for (size_t i = base_num_inputs; i < in_types->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
+
+    if (param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+    } else {
+      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kInt8);
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kInt32);
+      }
+      TYPE_ASSIGN_CHECK(*out_types, DotOut::out_min, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*out_types, DotOut::out_max, mshadow::kFloat32);
+    }
+  } else {
+    TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kFloat32);
+    TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kFloat32);
+    TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+  }
+
+  return true;
+}
+
+inline static bool DNNLBatchDotStorageType(const nnvm::NodeAttrs& attrs,
+                                           const int dev_mask,
+                                           DispatchMode* dispatch_mode,
+                                           std::vector<int>* in_attrs,
+                                           std::vector<int>* out_attrs) {
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+NNVM_REGISTER_OP(_sg_onednn_batch_dot)
+    .describe(R"code(_sg_onednn_batch_dot)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
+      // two normal inputs + min/max for quantized version
+      return param.quantized ? 6 : 2;
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
+      return (param.quantized && !param.enable_float_output) ? 3 : 1;
+    })
+    .set_attr_parser(ParamParser<DNNLDotParam>)
+    .set_attr<nnvm::FListInputNames>(
+        "FListInputNames",
+        [](const NodeAttrs& attrs) {
+          auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
+          if (param.quantized) {
+            return std::vector<std::string>{
+                "lhs", "rhs", "min_lhs", "max_lhs", "min_rhs", "max_rhs"};
+          } else {
+            return std::vector<std::string>{"lhs", "rhs"};
+          }
+        })
+    .set_attr<nnvm::FListOutputNames>(
+        "FListOutputNames",
+        [](const NodeAttrs& attrs) {
+          auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
+          if (param.quantized && !param.enable_float_output) {
+            return std::vector<std::string>{"output", "min_output", "max_output"};
+          } else {
+            return std::vector<std::string>{"output"};
+          }
+        })
+    .set_attr<mxnet::FInferShape>("FInferShape", DNNLBatchDotShape)
+    .set_attr<nnvm::FInferType>("FInferType", DNNLBatchDotType)
+    .set_attr<FInferStorageType>("FInferStorageType", DNNLBatchDotStorageType)
+    .set_attr<FComputeEx>("FComputeEx", DNNLBatchDotForward<true>)
+    .set_attr<bool>("TIsDNNL", true)
+    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+    .set_attr<FQuantizable>("FQuantizable",
+                            [](const NodeAttrs& attrs) { return QuantizeType::kMust; })
+    .set_attr<FQuantizedOp>("FQuantizedOp",
+                            [](const NodeAttrs& attrs) {
+                              nnvm::ObjectPtr node = nnvm::Node::Create();
+                              node->attrs.op       = Op::Get("_sg_onednn_batch_dot");
+                              node->attrs.name     = "quantized_" + attrs.name;
+                              node->attrs.dict     = attrs.dict;
+                              node->attrs.dict["quantized"] = "True";
+
+                              if (node->op()->attr_parser != nullptr) {
+                                node->op()->attr_parser(&(node->attrs));
+                              }
+                              return node;
+                            })
+    .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; });
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_ONEDNN == 1
diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h b/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h
new file mode 100644
index 000000000000..d2f33aa1cc5a
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_BATCH_DOT_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_BATCH_DOT_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <string>
+#include <vector>
+
+#include "../../tensor/dot-inl.h"
+#include "../common.h"
+
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgDNNLBatchDotSelector : public SubgraphSelector {
+ public:
+  bool Select(const nnvm::Node& n) override {
+    return n.op() && n.op()->name == "batch_dot";
+  }
+
+  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+    return false;
+  }
+};
+
+class SgDNNLBatchDotProperty : public SubgraphProperty {
+ public:
+  static SubgraphPropertyPtr Create() {
+    static const std::string& name = "DNNL Batch Dot optimization pass";
+    auto property                  = std::make_shared<SgDNNLBatchDotProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_BATCH_DOT_FUSE", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+
+    std::ostringstream node_name;
+    node_name << "sg_dnnl_batch_dot_" << std::to_string(subgraph_id);
+
+    DotParam param;
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
+      if (node->op() && node->op()->name == "batch_dot") {
+        param = nnvm::get<DotParam>(node->attrs.parsed);
+      }
+    });
+
+    n->attrs.name = node_name.str();
+    n->attrs.op   = Op::Get("_sg_onednn_batch_dot");
+    CHECK(n->attrs.op);
+    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(sym));
+    n->attrs.dict["transpose_a"] = std::to_string(param.transpose_a);
+    n->attrs.dict["transpose_b"] = std::to_string(param.transpose_b);
+    n->attrs.dict["quantized"]   = "False";
+    n->op()->attr_parser(&(n->attrs));
+
+    return n;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector = std::make_shared<SgDNNLBatchDotSelector>();
+    return selector;
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_BATCH_DOT_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
similarity index 82%
rename from src/operator/subgraph/dnnl/dnnl_transformer_post_quantize_property.h
rename to src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
index 7528de54083d..6fbd97fd1f56 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
@@ -17,8 +17,8 @@
  * under the License.
  */
 
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
 #if MXNET_USE_ONEDNN == 1
 
 #include <string>
@@ -31,7 +31,7 @@
 namespace mxnet {
 namespace op {
 
-class SgDNNLTransformerPostQuantizeSelector : public SubgraphSelector {
+class SgDNNLMatmulPostQuantizeSelector : public SubgraphSelector {
  public:
   /*! \brief pattern match status */
   enum SelectStatus {
@@ -48,12 +48,13 @@ class SgDNNLTransformerPostQuantizeSelector : public SubgraphSelector {
   std::vector<const nnvm::Node*> matched_list;
 
  public:
-  explicit SgDNNLTransformerPostQuantizeSelector(const bool dis_all, const bool dis_float_output)
+  explicit SgDNNLMatmulPostQuantizeSelector(const bool dis_all, const bool dis_float_output)
       : disable_all(dis_all), disable_float_output(dis_float_output) {}
 
   bool Select(const nnvm::Node& n) override {
     if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") ||
-                           n.op() == Op::Get("_sg_onednn_selfatt_valatt"))) {
+                           n.op() == Op::Get("_sg_onednn_selfatt_valatt") ||
+                           n.op() == Op::Get("_sg_onednn_batch_dot"))) {
       status = disable_all ? kSuccess : kStart;
       matched_list.clear();
       matched_list.push_back(&n);
@@ -121,22 +122,22 @@ class SgDNNLTransformerPostQuantizeSelector : public SubgraphSelector {
 
   void Reset() override {
     CHECK_GE(matched_list.size(), 1);
-    auto new_selector = SgDNNLTransformerPostQuantizeSelector(disable_all, disable_float_output);
+    auto new_selector = SgDNNLMatmulPostQuantizeSelector(disable_all, disable_float_output);
     new_selector.Select(*matched_list[0]);
     *this = new_selector;
   }
 };
 
-class SgDNNLTransformerPostQuantizeProperty : public SubgraphProperty {
+class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty {
  public:
-  SgDNNLTransformerPostQuantizeProperty() {
-    disable_fuse_all     = dmlc::GetEnv("MXNET_DISABLE_DNNL_QTRANSFORMER_FUSE_ALL", false);
-    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_DNNL_QTRANSFORMER_FLOAT_OUTPUT", false);
+  SgDNNLMatmulPostQuantizeProperty() {
+    disable_fuse_all     = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FUSE_ALL", false);
+    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FLOAT_OUTPUT", false);
   }
 
   static SubgraphPropertyPtr Create() {
-    static const std::string& name = "DNNL Transformer post-quantization optimization pass";
-    auto property = std::make_shared<SgDNNLTransformerPostQuantizeProperty>();
+    static const std::string& name = "DNNL Matmul post-quantization optimization pass";
+    auto property = std::make_shared<SgDNNLMatmulPostQuantizeProperty>();
     property->SetAttr<std::string>("property_name", name);
     property->SetAttr<bool>("inference_only", true);
     return property;
@@ -152,7 +153,8 @@ class SgDNNLTransformerPostQuantizeProperty : public SubgraphProperty {
     if (node->is_variable())
       return;
     if (node->op() == Op::Get("_sg_onednn_selfatt_qk") ||
-        node->op() == Op::Get("_sg_onednn_selfatt_valatt")) {
+        node->op() == Op::Get("_sg_onednn_selfatt_valatt") ||
+        node->op() == Op::Get("_sg_onednn_batch_dot")) {
       interleaved_node = node;
     } else if (node->op() == Op::Get("_contrib_requantize")) {
       requantize_node = node;
@@ -183,8 +185,8 @@ class SgDNNLTransformerPostQuantizeProperty : public SubgraphProperty {
   }
 
   SubgraphSelectorPtr CreateSubgraphSelector() const override {
-    auto selector = std::make_shared<SgDNNLTransformerPostQuantizeSelector>(disable_fuse_all,
-                                                                            disable_float_output);
+    auto selector =
+        std::make_shared<SgDNNLMatmulPostQuantizeSelector>(disable_fuse_all, disable_float_output);
     return selector;
   }
 
@@ -197,4 +199,4 @@ class SgDNNLTransformerPostQuantizeProperty : public SubgraphProperty {
 }  // namespace mxnet
 
 #endif  // if MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index de2ac27dad9e..4a5f6a6d129f 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -19,14 +19,15 @@
 
 #if MXNET_USE_ONEDNN == 1
 
+#include "dnnl_batch_dot_property.h"
 #include "dnnl_bn_relu_property.h"
 #include "dnnl_conv_property.h"
 #include "dnnl_elemwisemul_post_quantize_property.h"
 #include "dnnl_fc_post_quantize_property.h"
 #include "dnnl_fc_property.h"
+#include "dnnl_matmul_post_quantize_property.h"
 #include "dnnl_post_quantize_align_scale_property.h"
 #include "dnnl_post_quantize_property.h"
-#include "dnnl_transformer_post_quantize_property.h"
 #include "dnnl_transformer_qk_property.h"
 #include "dnnl_transformer_valatt_property.h"
 
@@ -42,20 +43,21 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
 
 MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLConvProperty).set_attr("quantize", true);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCProperty).set_attr("quantize", true);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerQKProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerValAttProperty);
-
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
+    .set_attr("quantize", true);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeAlignScaleProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerPostQuantizeProperty)
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLMatmulPostQuantizeProperty)
     .set_attr("quantize", true);
 
 }  // namespace op
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
index e0844f7a7e5f..3be675cfc45d 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -136,7 +136,7 @@ class SgDNNLTransformerQKSelector : public SubgraphSelector {
   }
 
   std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
-    if (status_ == kFail) {
+    if (status_ != kSuccess) {
       return std::vector<nnvm::Node*>(0);
     } else {
       std::vector<nnvm::Node*> ret;
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 839b84c9c4d6..863ef28598ec 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -1516,14 +1516,15 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template <typename ParamType>
 inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
                           mxnet::ShapeVector* in_attrs,
                           mxnet::ShapeVector* out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
-  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  mxnet::TShape& lshape = (*in_attrs)[0];
-  mxnet::TShape& rshape = (*in_attrs)[1];
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
+  mxnet::TShape& lshape  = (*in_attrs)[0];
+  mxnet::TShape& rshape  = (*in_attrs)[1];
   // return false if lhs and rhs both have fully unknown shape
   if (!ndim_is_known(lshape) || !ndim_is_known(rshape))
     return false;
@@ -1564,16 +1565,4 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
 }  // namespace op
 }  // namespace mxnet
 
-namespace std {
-template <>
-struct hash<mxnet::op::DotParam> {
-  size_t operator()(const mxnet::op::DotParam& val) {
-    size_t ret = 0;
-    ret        = dmlc::HashCombine(ret, val.transpose_a);
-    ret        = dmlc::HashCombine(ret, val.transpose_b);
-    ret        = dmlc::HashCombine(ret, val.forward_stype);
-    return ret;
-  }
-};
-}  // namespace std
 #endif  // MXNET_OPERATOR_TENSOR_DOT_INL_H_
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 9a19d0c6e754..defed11eb3ea 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -26,6 +26,7 @@
 #if MXNET_USE_ONEDNN == 1
 #include "./../nn/dnnl/dnnl_base-inl.h"
 #include "./../nn/dnnl/dnnl_ops-inl.h"
+#include "./../nn/dnnl/dnnl_batch_dot-inl.h"
 #endif  // MXNET_USE_ONEDNN
 
 namespace mxnet {
@@ -123,7 +124,7 @@ static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
                                  const std::vector<NDArray>& outputs) {
   if (SupportDNNLBatchDot(inputs, outputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    DNNLRun(DNNLBatchDotForward, attrs, ctx, inputs, req, outputs);
+    DNNLRun(DNNLBatchDotForward<false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -163,7 +164,7 @@ which is computed by::
                                 [](const NodeAttrs& attrs) {
                                   return std::vector<std::string>{"lhs", "rhs"};
                                 })
-    .set_attr<mxnet::FInferShape>("FInferShape", BatchDotShape)
+    .set_attr<mxnet::FInferShape>("FInferShape", BatchDotShape<DotParam>)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& attrs) {
diff --git a/tests/python/dnnl/subgraphs/test_transformer_subgraph.py b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
similarity index 67%
rename from tests/python/dnnl/subgraphs/test_transformer_subgraph.py
rename to tests/python/dnnl/subgraphs/test_matmul_subgraph.py
index 0c24bc26cfc5..b0628b7b7d6a 100644
--- a/tests/python/dnnl/subgraphs/test_transformer_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
@@ -67,14 +67,13 @@ def forward(self, x, mask):
   net.hybridize()
   ref_out = net(in_data, mask)
 
-  fused_net.optimize_for(in_data, mask, backend="DNNL")
+  fused_net.optimize_for(in_data, mask, backend="ONEDNN")
   out = fused_net(in_data, mask)
   mx.nd.waitall()
 
   for i in range(len(out)):
     assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())
 
-
   calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
   qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                        exclude_layers=None,
@@ -92,3 +91,51 @@ def forward(self, x, mask):
     max_range = np.max(ref_out[i].asnumpy())
     atol = 0.1 * max(abs(min_range), abs(max_range))
     assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+@use_np
+@pytest.mark.parametrize('batch_size', [1, 32])
+@pytest.mark.parametrize('seq_length', [124, 384])
+@pytest.mark.parametrize('units', [256, 768])
+@pytest.mark.parametrize('num_heads', [4, 8])
+def test_batch_dot(batch_size, seq_length, units, num_heads):
+  class BatchDotBlock(nn.HybridBlock):
+    def __init__(self, **kwargs):
+      super(BatchDotBlock, self).__init__(**kwargs)
+
+    def forward(self, lhs, rhs):
+      x = mx.npx.batch_dot(lhs, rhs)
+      return x
+
+  lhs_data = mx.np.random.uniform(low=-1, high=1, size=[batch_size, units, seq_length], dtype='float32')
+  rhs_data = mx.np.random.uniform(low=-1, high=1, size=[batch_size, seq_length, seq_length], dtype='float32')
+
+  net = BatchDotBlock()
+  net.initialize()
+  fused_net = net
+  net.hybridize()
+  ref_out = net(lhs_data, rhs_data)
+
+  fused_net.optimize_for(lhs_data, rhs_data, backend="ONEDNN")
+  out = fused_net(lhs_data, rhs_data)
+  mx.nd.waitall()
+
+  for i in range(len(out)):
+    assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())
+
+  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(lhs_data, rhs_data), batch_size=1)
+  qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
+                                       exclude_layers=None,
+                                       exclude_layers_match=None,
+                                       calib_data=calib_data,
+                                       calib_mode='naive',
+                                       num_calib_batches=1,
+                                       ctx=mx.cpu())
+
+  qout = qnet(lhs_data, rhs_data)
+  mx.nd.waitall()
+
+  for i in range(len(ref_out)):
+    min_range = np.min(ref_out[i].asnumpy())
+    max_range = np.max(ref_out[i].asnumpy())
+    atol = 0.1 * max(abs(min_range), abs(max_range))
+    assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.1)