diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index 9d8ff4e975d8..d942051c0398 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ b/python/mxnet/amp/lists/symbol_fp16.py @@ -618,6 +618,8 @@ FP32_FUNCS.extend([ '_sg_mkldnn_conv', '_sg_mkldnn_fully_connected', + '_sg_mkldnn_selfatt_qk', + '_sg_mkldnn_selfatt_valatt', ]) # Functions that have to be cast to FP32 only for diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 6ea74fe7f97d..0cefee7316c2 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -855,7 +855,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize while True: try: network(*data_nd) - except TypeError as err: + except (ValueError, TypeError) as err: if logger: logger.warning(err) logger.warning("Deduced input data descriptors failed to run forward pass." diff --git a/src/operator/subgraph/mkldnn/mkldnn_common.h b/src/operator/subgraph/mkldnn/mkldnn_common.h index c06f3f939b4f..d914e07c224d 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_common.h +++ b/src/operator/subgraph/mkldnn/mkldnn_common.h @@ -28,6 +28,7 @@ #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_ #if MXNET_USE_ONEDNN == 1 #include +#include "../../numpy/np_matrix_op-inl.h" namespace mxnet { namespace op { @@ -86,7 +87,7 @@ static std::vector GetWeightScales(const NDArray &weight, const NDArray * return weight_scales; } -static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias, +static inline void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias, const mkldnn::memory::desc &weight_md, const mkldnn::memory::desc *bias_md, const int num_group, float data_scale, @@ -131,6 +132,34 @@ static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bi if (has_bias && data_scale) *bias = new_bias; } + +static inline bool CheckReshapeConditions(const nnvm::Node& node, const index_t out_index) { + const index_t split_output_index = node.inputs[0].index; + if (split_output_index != out_index) + return false; + + const auto &reshape_param = nnvm::get(node.attrs.parsed); + const auto newshape = reshape_param.newshape; + + if (newshape.ndim() != 4 || !(newshape[0] == newshape[1] && newshape[0] == -2)) + return false; + + return true; +} + +static inline bool CheckSwapAxisConditions(const nnvm::Node& node) { + auto params = node.attrs.dict; + int dim1 = 0, dim2 = 0; + if (params.count("dim1") && params.count("dim2")) { + dim1 = std::stoi(params.at("dim1")); + dim2 = std::stoi(params.at("dim2")); + } else { + return false; + } + + return ((dim1 == 1 && dim2 == 2) || (dim1 == 2 && dim2 == 1)); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc index 5a6223f5e57d..3a8341f9d1bd 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc @@ -26,6 +26,9 @@ #include "mkldnn_fc_post_quantize_property.h" #include "mkldnn_elemwisemul_post_quantize_property.h" #include "mkldnn_post_quantize_align_scale_property.h" +#include "mkldnn_transformer_qk_property.h" +#include "mkldnn_transformer_valatt_property.h" +#include "mkldnn_transformer_post_quantize_property.h" namespace mxnet { namespace op { @@ -37,6 +40,8 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN) MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNBNReLUProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerQKProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerValAttProperty); MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE) .set_attr("context", Context::CPU()); @@ -46,11 +51,15 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty) MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty) .set_attr("quantize", true); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerQKProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerValAttProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty) +.set_attr("quantize", true); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h new file mode 100644 index 000000000000..f0e80350e801 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_ + +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" + + +namespace mxnet { +namespace op { + +struct MKLDNNSelfAttParam : public dmlc::Parameter { + int heads; + bool quantized; + bool enable_float_output; + dmlc::optional min_calib_range; // min float value calculated from calibration dataset + dmlc::optional max_calib_range; // max float value calculated from calibration dataset + DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) { + DMLC_DECLARE_FIELD(heads) + .describe("Set number of heads."); + DMLC_DECLARE_FIELD(quantized).set_default(false) + .describe("Whether it's a quantized self attention matmul operator."); + DMLC_DECLARE_FIELD(enable_float_output).set_default(false) + .describe("Whether to enable float32 output."); + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized self-attention op to calculate primitive scale."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized self-attention op to calculate primitive scale."); + } +}; + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc new file mode 100644 index 000000000000..91529bc0a652 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -0,0 +1,763 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +#if MXNET_USE_ONEDNN == 1 + +#include +#include +#include +#include "../common.h" +#include "./mkldnn_transformer-inl.h" +#include "../../contrib/transformer-inl.h" +#include "../../tensor/elemwise_unary_op.h" + +#include "../../quantization/quantization_utils.h" + +// 3 tensors within one (queries key values) = +#define QKV_NUM 3 + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(MKLDNNSelfAttParam); + +static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get(attrs.parsed); + auto qkv_shape = in_shape->at(0); + CHECK_EQ(qkv_shape.ndim(), 3U) + << "Input queries_keys_values should be 3D in batch-seq_length-proj_dim, " + << "but the given tensor is " << qkv_shape.ndim() << "D"; + + if (params.quantized) { + CHECK_EQ(in_shape->size(), 3U) << "Input: [queries_keys_values, min_qkv, max_qkv] " + << "- currently have " << in_shape->size() << " inputs"; + + SHAPE_ASSIGN_CHECK(*in_shape, 1, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*in_shape, 2, mxnet::TShape({1})); + + out_shape->resize(3); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]})); + if (!params.enable_float_output) { + SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); // min output + SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); // max output + } + } else { + CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] - currently have " + << in_shape->size() << " inputs"; + out_shape->resize(1); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]})); + } + + return true; +} + +static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { + const auto& params = nnvm::get(attrs.parsed); + if (params.quantized) { + CHECK_EQ(in_types->size(), 3U); + + CHECK(in_types->at(0) == mshadow::kInt8) + << "QuantizedSelfAttentionQK only supports int8 input, while " + << in_types->at(0) << " is given."; + + TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32); + + if (params.enable_float_output) { + CHECK_EQ(out_types->size(), 1U); + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); + } else { + CHECK_EQ(out_types->size(), 3U); + if (params.min_calib_range.has_value() && params.max_calib_range.has_value()) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); + } + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); + } + } else { + CHECK_EQ(in_types->size(), 1U); + CHECK_EQ(out_types->size(), 1U); + TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); + } + + return true; +} + + +class SgMKLDNNSelfAttQKOp { + public: + explicit SgMKLDNNSelfAttQKOp(const nnvm::NodeAttrs &attrs) : + param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn self attention qk only supports " + "inference computation."; + } + + void Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + bool IsInitialized() { + return initialized_; + } + + private: + bool initialized_{false}; + MKLDNNSelfAttParam param_; + mkldnn_args_map_t args_; + std::shared_ptr fwd_; + std::shared_ptr cached_query_mem_; + std::shared_ptr cached_key_mem_; + std::shared_ptr cached_out_mem_; + float min_data_; + float max_data_; + float min_output_; + float max_output_; + float data_scale_{0.0f}; +}; + +static OpStatePtr CreateSgMKLDNNSelfAttQKState(const nnvm::NodeAttrs &attrs, + Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + +static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNSelfAttQKOp &op = state_pointer.get_state(); + if (!op.IsInitialized()) { + op.Initialize(ctx, inputs, req, outputs); + } + op.Forward(ctx, inputs, req, outputs); +} + +static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +} + +void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mkldnn; + + const auto qkv_tensor = inputs[0]; + const auto out_tensor = outputs[0]; + + const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype()); + + const memory::dim heads = param_.heads; + const memory::dim sequences = inputs[0].shape()[0]; + const memory::dim qkv_seq_len = inputs[0].shape()[1]; + const memory::dim output_lin_dim = inputs[0].shape()[2]; + const memory::dim embed_dim = output_lin_dim / QKV_NUM; + const memory::dim head_dim = embed_dim / heads; + const memory::dim batch_stride = output_lin_dim * qkv_seq_len; + + float min_data = 0.0f; + float max_data = 0.0f; + + const auto engine = CpuEngine::Get()->get_engine(); + + memory::dims query_dims = {sequences, heads, qkv_seq_len, head_dim}; + memory::dims key_dims = {sequences, heads, head_dim, qkv_seq_len}; + memory::dims out_dims = {sequences, heads, qkv_seq_len, qkv_seq_len}; + + memory::dims query_strides = {batch_stride, head_dim, output_lin_dim, 1}; + memory::dims key_strides = {batch_stride, head_dim, 1, output_lin_dim}; + + auto query_md = memory::desc(query_dims, qkv_dtype, query_strides); + auto key_md = memory::desc(key_dims, qkv_dtype, key_strides); + + memory::desc out_md; + + float oscale = 1.0f; + if (param_.quantized) { + min_data_ = inputs[1].data().dptr()[0]; + max_data_ = inputs[2].data().dptr()[0]; + data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_); + + if (param_.min_calib_range.has_value() && + param_.max_calib_range.has_value()) { + min_output_ = param_.min_calib_range.value(); + max_output_ = param_.max_calib_range.value(); + oscale = GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) / + (data_scale_ * data_scale_); + out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abcd); + } else if (param_.enable_float_output) { + oscale = 1.0f / (data_scale_ * data_scale_); + out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abcd); + } else { + mshadow::Stream *s = ctx.get_stream(); + mxnet_op::Kernel::Launch( + s, 1, &min_output_, &max_output_, &min_data, &max_data, &min_data, + &max_data); + out_md = dnnl::memory::desc(out_dims, memory::data_type::s32, memory::format_tag::abcd); + } + } else { + out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abcd); + } + + dnnl::primitive_attr attr; + attr.set_output_scales(0, {oscale}); + auto matmul_d = matmul::desc(query_md, key_md, out_md); + auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine); + fwd_ = std::make_shared(matmul_pd); + + MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { + DType* query_mem_ptr = inputs[0].data().dptr(); + DType* key_mem_ptr = query_mem_ptr + embed_dim; + cached_query_mem_ = std::make_shared(query_md, engine, query_mem_ptr); + cached_key_mem_ = std::make_shared(key_md, engine, key_mem_ptr); + }); + + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + cached_out_mem_ = std::make_shared(out_md, engine, outputs[0].data().dptr()); + }); + + args_[DNNL_ARG_SRC] = *cached_query_mem_; + args_[DNNL_ARG_WEIGHTS] = *cached_key_mem_; + args_[DNNL_ARG_DST] = *cached_out_mem_; + initialized_ = true; +} + + +void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const size_t output_lin_dim = inputs[0].shape()[2]; + const size_t embed_dim = output_lin_dim / QKV_NUM; + + MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { + DType* query_mem_ptr = inputs[0].data().dptr(); + DType* key_mem_ptr = query_mem_ptr + embed_dim; + cached_query_mem_->set_data_handle(query_mem_ptr); + cached_key_mem_->set_data_handle(key_mem_ptr); + }); + + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + cached_out_mem_->set_data_handle(outputs[0].data().dptr()); + }); + + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); + MKLDNNStream::Get()->Submit(); + + if (param_.quantized && !param_.enable_float_output) { + float* output_min = outputs[1].data().dptr(); + float* output_max = outputs[2].data().dptr(); + + *output_min = min_output_; + *output_max = max_output_; + } +} + +nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) { + nnvm::ObjectPtr node = nnvm::Node::Create(); + auto const ¶m = nnvm::get(attrs.parsed); + node->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["heads"] = std::to_string(param.heads); + node->attrs.dict["quantized"] = "True"; + node->attrs.subgraphs.reserve(attrs.subgraphs.size()); + node->attrs.subgraphs = attrs.subgraphs; + node->op()->attr_parser(&(node->attrs)); + return node; +} + +NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk) +.describe(R"code(_sg_mkldnn_selfatt_qk)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + if (param.quantized) { + return 3; + } else { + return 1; + } +}) +.set_num_outputs([](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + if (param.quantized && !param.enable_float_output) { + return 3; + } else { + return 1; + } +}) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + std::vector input_names {"queries_keys_values"}; + if (param.quantized) { + input_names.emplace_back("min_qkv"); + input_names.emplace_back("max_qkv"); + } + return input_names; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + std::vector output_names {"output"}; + if (param.quantized && !param.enable_float_output) { + output_names.emplace_back("min_output"); + output_names.emplace_back("max_output"); + } + return output_names; +}) +.set_attr("FInferShape", SgMKLDNNSelfAttShape) +.set_attr("FInferType", SgMKLDNNSelfAttQKInferType) +.set_attr("FInferStorageType", SgMKLDNNSelfAttStorageType) +.set_attr("FCreateOpState", CreateSgMKLDNNSelfAttQKState) +.set_attr("FStatefulComputeEx", SgMKLDNNSelfAttQKForward) +.set_attr("TIsMKLDNN", true) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) +.set_attr("FQuantizedOp", SgMKLDNNSelfAttQKQuantizedOp) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values") +.add_arguments(MKLDNNSelfAttParam::__FIELDS__()); + +/**********************************_sg_mkldnn_selfatt_valatt**********************************/ + +static bool SgMKLDNNSelfAttValShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get(attrs.parsed); + auto att_shape = in_shape->at(0); + auto qkv_shape = in_shape->at(1); + + CHECK_EQ(att_shape.ndim(), 4U) + << "Attention maps should be 4D in batch-heads-seq_length-seq_length, " + << "but the given tensor is " << att_shape.ndim() << "D"; + + CHECK_EQ(qkv_shape.ndim(), 3U) + << "Input queries_keys_values should be 3D in batch-seq_length-proj_dim, " + << "but the given tensor is " << qkv_shape.ndim() << "D"; + + if (params.quantized) { + CHECK_EQ(in_shape->size(), 6U) << "Input:[attention, queries_keys_values, " + << "attn_min, attn_max, qkv_min, qkv_max] - currently have " + << in_shape->size() << " inputs"; + for (int i = 2; i < 6; i++) { + SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape({1})); + } + + out_shape->resize(3); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({att_shape[0], + att_shape[2], + att_shape[1] * qkv_shape[2] / params.heads / QKV_NUM})); + if (!params.enable_float_output) { + SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); // min output + SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); // max output + } + } else { + CHECK_EQ(in_shape->size(), 2U) << "Inputs: [queries_keys_values, attention] - currently have " + << in_shape->size() << " inputs"; + auto qkv_shape = in_shape->at(1); + auto att_shape = in_shape->at(0); + CHECK_EQ(qkv_shape.ndim(), 3U) + << "Input queries_keys_values should be 3D in batch-seq_length-proj_dim, " + << "but the given tensor is " << qkv_shape.ndim() << "D"; + out_shape->resize(1); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({att_shape[0], + att_shape[2], + att_shape[1] * qkv_shape[2] / params.heads / QKV_NUM})); + return true; + } + + return true; +} + +static bool SgMKLDNNSelfAttValInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { + const auto& params = nnvm::get(attrs.parsed); + + if (params.quantized) { + CHECK_EQ(in_types->size(), 6U) << "Input:[attention, queries_keys_values, min_att, max_att, " + "min_qkv, max_qkv] - currently have " + << in_types->size() << " inputs"; + + CHECK(in_types->at(0) == mshadow::kUint8) + << "QuantizedSelfAttentionQK only supports int8/uint8 input, while " + << in_types->at(0) << " is given."; + CHECK(in_types->at(1) == mshadow::kInt8 || + in_types->at(1) == mshadow::kUint8) + << "QuantizedSelfAttentionQK only supports int8/uint8 input, while " + << in_types->at(1) << " is given."; + for (int i = 2; i < 6; i++) { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } + + if (params.enable_float_output) { + CHECK_EQ(out_types->size(), 1U); + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); + } else { + CHECK_EQ(out_types->size(), 3U); + if (params.min_calib_range.has_value() && params.max_calib_range.has_value()) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); + } + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); + } + } else { + CHECK_EQ(in_types->size(), 2U); + CHECK_EQ(out_types->size(), 1U); + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32); + } + + return true; +} + +nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) { + nnvm::ObjectPtr node = nnvm::Node::Create(); + auto const ¶m = nnvm::get(attrs.parsed); + node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["heads"] = std::to_string(param.heads); + node->attrs.dict["quantized"] = "True"; + node->attrs.subgraphs.reserve(attrs.subgraphs.size()); + node->attrs.subgraphs = attrs.subgraphs; + node->op()->attr_parser(&(node->attrs)); + return node; +} + +class MKLDNNSelfAttValAttOp { + public: + explicit MKLDNNSelfAttValAttOp(const nnvm::NodeAttrs &attrs) : + param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn self attention val only supports " + "inference computation."; + } + + void Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + bool IsInitialized() { + return initialized_; + } + + private: + bool initialized_{false}; + MKLDNNSelfAttParam param_; + mkldnn_args_map_t args_; + mkldnn_args_map_t reorder_args; + std::shared_ptr fwd_; + std::shared_ptr reorder_; + std::shared_ptr cached_att_mem_; + std::shared_ptr cached_value_mem_; + std::shared_ptr cached_result_mem_; + std::shared_ptr cached_tmp_mem_; + std::shared_ptr cached_transposed_mem_; // op output + float min_qkv_; + float max_qkv_; + float min_att_; + float max_att_; + float min_output_; + float max_output_; + float qkv_scale_{0.0f}; + float att_scale_{0.0f}; +}; + +static OpStatePtr CreateMKLDNNSelfAttValAttState(const nnvm::NodeAttrs &attrs, + Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + +static void MKLDNNSelfAttValAttForward(const OpStatePtr &state_pointer, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + MKLDNNSelfAttValAttOp &op = state_pointer.get_state(); + if (!op.IsInitialized()) { + op.Initialize(ctx, inputs, req, outputs); + } + op.Forward(ctx, inputs, req, outputs); +} + +void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mkldnn; + + const auto attn_tensor = inputs[0]; + const auto qkv_tensor = inputs[1]; + const auto out_tensor = outputs[0]; + + const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype()); + const auto attn_dtype = get_mkldnn_type(attn_tensor.dtype()); + + const memory::dim heads = param_.heads; + const memory::dim sequences = qkv_tensor.shape()[0]; + const memory::dim qkv_seq_len = qkv_tensor.shape()[1]; + const memory::dim output_lin_dim = qkv_tensor.shape()[2]; + const memory::dim embed_dim = output_lin_dim / QKV_NUM; + const memory::dim head_dim = embed_dim / heads; + const memory::dim batch_stride = output_lin_dim * qkv_seq_len; + + const auto engine = CpuEngine::Get()->get_engine(); + + memory::dims attn_dims = {sequences, heads, qkv_seq_len, qkv_seq_len}; + memory::dims value_dims = {sequences, heads, qkv_seq_len, head_dim}; + memory::dims out_dims = {sequences, heads, qkv_seq_len, head_dim}; + + // needed to make transpose on 2nd and 3rd axis with oneDNN + memory::dims transpose_dims = {sequences, heads, qkv_seq_len, head_dim, 1}; + + memory::dims value_strides = {batch_stride, head_dim, output_lin_dim, 1}; + + // for attention tensor just use normal data layout, + // for value tensor we need to use strides as input tensor consists of queries, keys and values + const auto attn_md = memory::desc(attn_dims, attn_dtype, memory::format_tag::abcd); + const auto value_md = memory::desc(value_dims, qkv_dtype, value_strides); + + // result = attn * value + // tmp = result + artificial dimension (1) - same memory ptr as result + // transpose = transposed tmp - output + memory::desc result_md, tmp_md, transpose_md; + + float oscale = 1.0f; + auto result_mkldnn_dtype = memory::data_type::f32; + if (param_.quantized) { + min_att_ = inputs[2].data().dptr()[0]; + max_att_ = inputs[3].data().dptr()[0]; + min_qkv_ = inputs[4].data().dptr()[0]; + max_qkv_ = inputs[5].data().dptr()[0]; + + att_scale_ = GetQuantizeScale(mshadow::kUint8, min_att_, max_att_); + qkv_scale_ = GetQuantizeScale(mshadow::kInt8, min_qkv_, max_qkv_); + + if (param_.min_calib_range.has_value() && + param_.max_calib_range.has_value()) { + min_output_ = param_.min_calib_range.value(); + max_output_ = param_.max_calib_range.value(); + oscale = GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) / + (att_scale_ * qkv_scale_); + result_mkldnn_dtype = memory::data_type::s8; + } else if (param_.enable_float_output) { + oscale = 1.0f / (att_scale_ * qkv_scale_); + result_mkldnn_dtype = memory::data_type::f32; + } else { + mshadow::Stream *s = ctx.get_stream(); + mxnet_op::Kernel::Launch( + s, 1, &min_output_, &max_output_, &min_att_, &max_att_, &min_qkv_, + &max_qkv_); + result_mkldnn_dtype = memory::data_type::s32; + } + } else { + result_mkldnn_dtype = memory::data_type::f32; + } + + result_md = memory::desc(out_dims, result_mkldnn_dtype, memory::format_tag::abcd); + tmp_md = memory::desc(transpose_dims, result_mkldnn_dtype, memory::format_tag::abcde); + transpose_md = memory::desc(transpose_dims, result_mkldnn_dtype, memory::format_tag::acbde); + + // multiply by 2 as we need to skip query and key + const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2; + auto att_buffer = inputs[0]; + if (att_buffer.IsMKLDNNData()) + att_buffer = att_buffer.Reorder2Default(); + + MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, { + DType* attention_ptr = att_buffer.data().dptr(); + cached_att_mem_ = std::make_shared(attn_md, engine, attention_ptr); + }); + + MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { + DType* value_mem_ptr = inputs[1].data().dptr() + value_offset; + cached_value_mem_ = std::make_shared(value_md, engine, value_mem_ptr); + }); + + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + cached_result_mem_ = std::make_shared(result_md, engine); + DType *orig_buf = reinterpret_cast(cached_result_mem_->get_data_handle()); + cached_tmp_mem_ = std::make_shared(tmp_md, engine, orig_buf); + cached_transposed_mem_ = std::make_shared(transpose_md, + engine, + outputs[0].data().dptr()); + }); + + dnnl::primitive_attr attr; + attr.set_output_scales(0, {oscale}); + auto matmul_d = matmul::desc(attn_md, value_md, result_md); + auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine); + fwd_ = std::make_shared(matmul_pd); + args_[DNNL_ARG_SRC] = *cached_att_mem_; + args_[DNNL_ARG_WEIGHTS] = *cached_value_mem_; + args_[DNNL_ARG_DST] = *cached_result_mem_; + + auto reorder_pd = dnnl::reorder::primitive_desc(engine, tmp_md, engine, transpose_md); + reorder_ = std::make_shared(reorder_pd); + reorder_args[DNNL_ARG_SRC] = *cached_tmp_mem_; + reorder_args[DNNL_ARG_DST] = *cached_transposed_mem_; + + initialized_ = true; +} + +void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + // multiply by 2 as we need to skip queries and keys + const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2; + + auto att_buffer = inputs[0]; + if (att_buffer.IsMKLDNNData()) + att_buffer = att_buffer.Reorder2Default(); + + MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, { + DType* attention_ptr = att_buffer.data().dptr(); + cached_att_mem_->set_data_handle(attention_ptr); + }); + + MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { + DType* qkv_ptr = inputs[1].data().dptr(); + DType* value_mem_ptr = qkv_ptr + value_offset; + cached_value_mem_->set_data_handle(value_mem_ptr); + }); + + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + cached_transposed_mem_->set_data_handle(outputs[0].data().dptr()); + }); + + + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); + MKLDNNStream::Get()->RegisterPrimArgs(*reorder_, reorder_args); + MKLDNNStream::Get()->Submit(); + + if (param_.quantized && !param_.enable_float_output) { + float* output_min = outputs[1].data().dptr(); + float* output_max = outputs[2].data().dptr(); + + *output_min = min_output_; + *output_max = max_output_; + } +} + +NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt) +.describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + if (param.quantized) { + return 6; + } else { + return 2; + } +}) +.set_num_outputs([](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + if (param.quantized && !param.enable_float_output) { + return 3; + } else { + return 1; + } +}) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + std::vector input_names {"attention", "queries_keys_values"}; + if (param.quantized) { + input_names.emplace_back("min_attention"); + input_names.emplace_back("max_attention"); + + input_names.emplace_back("min_qkv"); + input_names.emplace_back("max_qkv"); + } + return input_names; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + std::vector output_names {"output"}; + if (param.quantized && !param.enable_float_output) { + output_names.emplace_back("min_output"); + output_names.emplace_back("max_output"); + } + return output_names; +}) +.set_attr("FInferShape", SgMKLDNNSelfAttValShape) +.set_attr("FInferType", SgMKLDNNSelfAttValInferType) +.set_attr("FInferStorageType", SgMKLDNNSelfAttStorageType) +.set_attr("FCreateOpState", CreateMKLDNNSelfAttValAttState) +.set_attr("FStatefulComputeEx", MKLDNNSelfAttValAttForward) +.set_attr("TIsMKLDNN", true) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) +.set_attr("FQuantizedOp", SgMKLDNNSelfAttValAttQuantizedOp) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.add_argument("attention", "NDArray-or-Symbol", "Attention maps") +.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") +.add_arguments(MKLDNNSelfAttParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet + +#endif diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h new file mode 100644 index 000000000000..d64e14f25a13 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_ +#if MXNET_USE_ONEDNN == 1 + +#include +#include +#include "../../quantization/requantize-inl.h" +#include "../common.h" +#include "mkldnn_subgraph_base-inl.h" + +namespace mxnet { +namespace op { + +class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector { + public: + /*! \brief pattern match status */ + enum SelectStatus { + kFail = 0, + kStart, + kRequantize, + kSuccess, + }; + + private: + bool disable_all; + bool disable_float_output; + SelectStatus status; + std::vector matched_list; + + public: + explicit SgMKLDNNTransformerPostQuantizeSelector(const bool dis_all, + const bool dis_float_output) + : disable_all(dis_all), + disable_float_output(dis_float_output) {} + + bool Select(const nnvm::Node &n) override { + if ((!disable_all) && + (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") || + n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) { + status = disable_all ? kSuccess : kStart; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + if (status == kFail || status == kSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encoutered a internal + // branch, we should pop out the node behind n and stop fusion. + if (matched_list.back() != &n) { + if (std::find(matched_list.begin(), matched_list.end(), &n) != + matched_list.end()) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + } + + status = kSuccess; + return false; + } + + switch (status) { + case kStart: + if (new_node.op() == Op::Get("_contrib_requantize")) { + auto const ¶m = nnvm::get(new_node.attrs.parsed); + if (param.min_calib_range.has_value() && + param.max_calib_range.has_value()) { + matched_list.push_back(&new_node); + status = kRequantize; + return true; + } + } + case kRequantize: + if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) { + matched_list.push_back(&new_node); + status = kSuccess; + return true; + } + default: + status = kSuccess; + return false; + } + } + + std::vector Filter( + const std::vector &candidates) override { + if ((status != kSuccess) || (matched_list.size() <= 1)) { + return std::vector(0); + } else { + std::vector ret; + for (auto i : matched_list) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != + candidates.end()) { + ret.push_back(non_const_i); + } + } + return ret; + } + } + + void Reset() override { + CHECK_GE(matched_list.size(), 1); + auto new_selector = SgMKLDNNTransformerPostQuantizeSelector(disable_all, disable_float_output); + new_selector.Select(*matched_list[0]); + *this = new_selector; + } +}; + +class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty { + public: + SgMKLDNNTransformerPostQuantizeProperty() { + disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FUSE_ALL", false); + disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FLOAT_OUTPUT", false); + } + + static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN Transformer post-quantization optimization pass"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + return property; + } + + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::ObjectPtr interleaved_node = nullptr; + nnvm::ObjectPtr requantize_node = nullptr; + nnvm::ObjectPtr dequantize_node = nullptr; + + DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) { + if (node->is_variable()) return; + if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") || + node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) { + interleaved_node = node; + } else if (node->op() == Op::Get("_contrib_requantize")) { + requantize_node = node; + } else if (node->op() == Op::Get("_contrib_dequantize")) { + dequantize_node = node; + } + }); + + CHECK_NOTNULL(interleaved_node); + CHECK_NOTNULL(requantize_node); + auto const &requantize_param = + nnvm::get(requantize_node->attrs.parsed); + CHECK(requantize_param.min_calib_range.has_value()); + CHECK(requantize_param.max_calib_range.has_value()); + + // When only fusing quantized_interleaved_matmul and requantize, set min/max_cablib_range, + // When fusing quantized_interleaved_matmul + requantize + dequantize, + // set dequantize flag to true. + if (dequantize_node != nullptr) { + interleaved_node->attrs.dict["enable_float_output"] = "True"; + } else { + interleaved_node->attrs.dict["min_calib_range"] = + std::to_string(requantize_param.min_calib_range.value()); + interleaved_node->attrs.dict["max_calib_range"] = + std::to_string(requantize_param.max_calib_range.value()); + } + interleaved_node->op()->attr_parser(&(interleaved_node->attrs)); + return interleaved_node; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = + std::make_shared(disable_fuse_all, + disable_float_output); + return selector; + } + + private: + bool disable_fuse_all; + bool disable_float_output; +}; + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_qk_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_qk_property.h new file mode 100644 index 000000000000..098a65c07e2a --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_qk_property.h @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_QK_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_QK_PROPERTY_H_ +#if MXNET_USE_ONEDNN == 1 + +#include +#include +#include +#include "../common.h" +#include "../../numpy/np_matrix_op-inl.h" +#include "../../contrib/transformer-inl.h" +#include "../../tensor/matrix_op-inl.h" +#include "mkldnn_common.h" +#include "mkldnn_subgraph_base-inl.h" +#include "mkldnn_transformer-inl.h" + +/* + custom_op + | + _____________|_________________ + | Split | + | / \ | + | _npx_reshape _npx_reshape | + | | | | + | SwapAxis SwapAxis | + | \ / | + | batch_dot | + | | | + |______________________________| +*/ +namespace mxnet { +namespace op { + +class SgMKLDNNTransformerQKSelector : public SubgraphSelector { + enum SelectStatus { + kFail = 0, + kStart, + kFirstSwapAx, + kSecondSwapAx, + kFirstReshape, + kSecondReshape, + kSuccess + }; + +/* + kStart ---> kFirstSwapAx ---> kSecondSwapAx ---> kFirstReshape ---> kSecondReshape ---> kSuccess + Each status except kStart is connected with kFail +*/ + + private: + SelectStatus status_; + std::vector matched_list_; + + public: + bool Select(const nnvm::Node &n, const std::shared_ptr& node_attr) override { + if (n.op() == Op::Get("batch_dot")) { + status_ = kStart; + matched_list_.clear(); + matched_list_.push_back(&n); + return true; + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + if (status_ == kFail || status_ == kSuccess || new_node.is_variable()) + return false; + + switch (status_) { + case kStart: + if (new_node.op() == Op::Get("SwapAxis")) { + if (CheckSwapAxisConditions(new_node)) { + status_ = kFirstSwapAx; + matched_list_.push_back(&new_node); + } + return true; + } + case kFirstSwapAx: + if (new_node.op() == Op::Get("SwapAxis")) { + if (CheckSwapAxisConditions(new_node)) { + status_ = kSecondSwapAx; + matched_list_.push_back(&new_node); + return true; + } + } + case kSecondSwapAx: + if (new_node.op() == Op::Get("_npx_reshape")) { + // input to reshape must be first or second output from split + if (CheckReshapeConditions(new_node, 0) || CheckReshapeConditions(new_node, 1)) { + status_ = kFirstReshape; + matched_list_.push_back(&new_node); + return true; + } + } + case kFirstReshape: + if (new_node.op() == Op::Get("_npx_reshape")) { + if (CheckReshapeConditions(new_node, 0) || CheckReshapeConditions(new_node, 1)) { + status_ = kSecondReshape; + matched_list_.push_back(&new_node); + return true; + } + } + case kSecondReshape: + if (new_node.op() == Op::Get("_split_v2")) { + status_ = kSuccess; + return true; + } + default: + status_ = kFail; + return false; + } + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + std::vector Filter( + const std::vector &candidates) override { + if (status_ == kFail) { + return std::vector(0); + } else { + std::vector ret; + for (auto i : matched_list_) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != + candidates.end()) { + ret.push_back(non_const_i); + } + } + return ret; + } + } + + void Reset() override { + CHECK_GE(matched_list_.size(), 1); + auto new_selector = SgMKLDNNTransformerQKSelector(); + new_selector.Select(*matched_list_[0], nullptr); + *this = new_selector; + } +}; + +class SgMKLDNNTransformerQKProperty : public SubgraphProperty { + public: + SgMKLDNNTransformerQKProperty() {} + + static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN Transformer optimization pass"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT", 0)) { + property->SetAttr("disable", true); + } + return property; + } + + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::ObjectPtr n = nnvm::Node::Create(); + // This op has single output, remove duplicated. + auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(last_node); + std::ostringstream node_name; + std::string op_name; + + DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) { + if ((node->op() == Op::Get("_npx_reshape"))) { + auto const &reshape_param = + nnvm::get(node->attrs.parsed); + // set heads attribute - all necessary conditions are checked before + n->attrs.dict["heads"] = std::to_string(reshape_param.newshape[2]); + } + }); + + node_name << "_sg_mkldnn_selfatt_qk_" << subgraph_id; + + n->attrs.name = node_name.str(); + n->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk"); + CHECK(n->attrs.op); + n->op()->attr_parser(&(n->attrs)); + return n; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = std::make_shared(); + return selector; + } + + void ConnectSubgraphOutputs( + const nnvm::ObjectPtr n, + std::vector *output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; + } + } + + + void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node, + std::vector* input_entries, + std::vector* orig_input_entries) + const override { + subgraph_node->inputs.resize(1); + // split is not part of subgraph, skip split as input and + // connect subgraph input with split input + subgraph_node->inputs[0] = orig_input_entries->at(0).node->inputs[0]; + } +}; + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_QK_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_valatt_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_valatt_property.h new file mode 100644 index 000000000000..845013d31d3d --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_valatt_property.h @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_VALATT_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_VALATT_PROPERTY_H_ +#if MXNET_USE_ONEDNN == 1 + +#include +#include +#include +#include "../common.h" +#include "../../contrib/transformer-inl.h" +#include "../../numpy/np_matrix_op-inl.h" +#include "../../swapaxis-inl.h" +#include "../../tensor/matrix_op-inl.h" +#include "mkldnn_common.h" +#include "mkldnn_subgraph_base-inl.h" +#include "mkldnn_transformer-inl.h" + +/* + custom_op + _________________|_________ + | Split | + | | | + | _npx_reshape | + | | | + | custom_op SwapAxis | + | \ / | + | batch_dot | + | | | + | transpose | + | | | + | reshape | + |__________________________| +*/ + +namespace mxnet { +namespace op { + +#define SELFATT_QK "_contrib_interleaved_matmul_selfatt_qk" +#define SELFATT_VALATT "_contrib_interleaved_matmul_selfatt_valatt" + + +bool CheckReshapeConditions(const BiDirectedNode& bi_node) { + const nnvm::Node* rawnode = bi_node.node; + return CheckReshapeConditions(*rawnode, 2); +} + +bool CheckSwapAxisConditions(const BiDirectedNode& bi_node) { + const nnvm::Node* rawnode = bi_node.node; + return CheckSwapAxisConditions(*rawnode); +} + +bool CheckSplitConditions(const BiDirectedNode& bi_node) { + const nnvm::Node* rawnode = bi_node.node; + auto const &split_params = nnvm::get(rawnode->attrs.parsed); + + if (split_params.axis != -1 || split_params.sections != 3 + || split_params.indices.ndim() != 0 || split_params.squeeze_axis != 0) { + return false; + } + + if (bi_node.outputs.size() != 1) { + return false; + } + return true; +} + +class SgMKLDNNTransformerValAttSelector : public SubgraphSelectorV2 { + enum InStatus { + kFail = 0, + kStart, + kSecondStart, + kIgnoreSecond, + kSwapAx, + kReshape, + kSuccess + }; + /* (custom_op) + /---> kSecondStart ---\ + kStart --> > kSwapAx --> kReshape --> kSuccess + \---> kIgnoreSecond ---/ + (SwapAxis recognized - tmp + state to drop second input) + + Each status except kStart is connected with kFail +*/ + + enum OutStatus { + oFail = 0, + oStart, + oTranspose, + oReshape, + oSuccess + }; + + + private: + InStatus in_status_; + OutStatus out_status_; + std::vector matched_list_; + + public: + bool Select(const BiDirectedNode& seed_node, + const std::shared_ptr& node_attr) override { + if (seed_node.node->op() == Op::Get("batch_dot")) { + in_status_ = InStatus::kStart; + out_status_ = OutStatus::oStart; + matched_list_.clear(); + matched_list_.push_back(&seed_node); + return true; + } + return false; + } + + bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &input_node) override { + if (in_status_ == InStatus::kFail || + in_status_ == InStatus::kSuccess || + input_node.node->is_variable()) + return false; + + switch (in_status_) { + case InStatus::kStart: + if (input_node.node->op() == Op::Get("SwapAxis")) { + in_status_ = InStatus::kIgnoreSecond; + matched_list_.push_back(&input_node); + return true; + } else { + in_status_ = InStatus::kSecondStart; + return false; + } + break; + case InStatus::kSecondStart: + if (input_node.node->op() == Op::Get("SwapAxis")) { + if (CheckSwapAxisConditions(input_node)) { + in_status_ = InStatus::kSwapAx; + matched_list_.push_back(&input_node); + return true; + } else { + return false; + } + } + break; + case InStatus::kSwapAx: + if (input_node.node->op() == Op::Get("_npx_reshape")) { + if (CheckReshapeConditions(input_node)) { + in_status_ = InStatus::kReshape; + matched_list_.push_back(&input_node); + return true; + } else { + return false; + } + } + break; + case InStatus::kReshape: + if (input_node.node->op() == Op::Get("_split_v2")) { + if (CheckSplitConditions(input_node)) { + in_status_ = InStatus::kSuccess; + matched_list_.push_back(&input_node); + return true; + } + } + break; + case kIgnoreSecond: + // BFS algorithm - we need to exclude single input of batch_dot (custom_op) + in_status_ = InStatus::kSwapAx; + return false; + default: + in_status_ = InStatus::kFail; + return false; + } + return false; + } + + bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &output_node) override { + if (out_status_ == OutStatus::oFail || + out_status_ == OutStatus::oSuccess || + output_node.node->is_variable()) + return false; + + switch (out_status_) { + case OutStatus::oStart: + if (output_node.node->op() == Op::Get("_npi_transpose")) { + auto const &transpose_params = + nnvm::get(output_node.node->attrs.parsed); + auto axes = transpose_params.axes; + if (axes.ndim() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 1 && axes[3] == 3) { + out_status_ = OutStatus::oTranspose; + matched_list_.push_back(&output_node); + return true; + } + } + case OutStatus::oTranspose: + if (out_status_ == OutStatus::oTranspose && + output_node.node->op() == Op::Get("_npx_reshape")) { + auto const &reshape_param = nnvm::get(output_node.node->attrs.parsed); + auto newshape = reshape_param.newshape; + if (newshape.ndim() == 3 && + newshape[2] == -1 && + (newshape[0] == newshape[1] && newshape[0] == -2)) { + out_status_ = OutStatus::oSuccess; + matched_list_.push_back(&output_node); + return true; + } + } + default: + out_status_ = OutStatus::oFail; + return false; + } + return false; + } + + std::vector Filter(const std::vector& candidates) override { + if (in_status_ == InStatus::kFail || in_status_ != InStatus::kSuccess || + out_status_ == OutStatus::oFail || out_status_ != OutStatus::oSuccess) { + return std::vector(0); + } else { + std::vector ret; + for (auto i : matched_list_) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != + candidates.end()) { + ret.push_back(non_const_i); + } + } + return ret; + } + } + + void Reset() override { + CHECK_GE(matched_list_.size(), 1); + auto new_selector = SgMKLDNNTransformerValAttSelector(); + new_selector.Select(*matched_list_[0], nullptr); + *this = new_selector; + } +}; + +class SgMKLDNNTransformerValAttProperty : public SubgraphProperty { + public: + SgMKLDNNTransformerValAttProperty() {} + + static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN Transformer optimization pass"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT", 0)) { + property->SetAttr("disable", true); + } + return property; + } + + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::ObjectPtr n = nnvm::Node::Create(); + // This op has single output, remove duplicated. + auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(last_node); + std::ostringstream node_name; + std::string op_name; + DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) { + if ((node->op() == Op::Get("_npx_reshape"))) { + auto const &reshape_param = nnvm::get(node->attrs.parsed); + if (reshape_param.newshape.ndim() == 4) + // set heads attribute - all necessary conditions are checked before + n->attrs.dict["heads"] = std::to_string(reshape_param.newshape[2]); + } + }); + node_name << "_sg_mkldnn_selfatt_valatt_" << subgraph_id; + n->attrs.name = node_name.str(); + n->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt"); + CHECK(n->attrs.op); + n->op()->attr_parser(&(n->attrs)); + return n; + } + + SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { + auto selector = std::make_shared(); + return selector; + } +}; + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_VALATT_PROPERTY_H_ diff --git a/tests/python/mkl/subgraphs/test_transformer_subgraph.py b/tests/python/mkl/subgraphs/test_transformer_subgraph.py new file mode 100644 index 000000000000..06daaf2ec24e --- /dev/null +++ b/tests/python/mkl/subgraphs/test_transformer_subgraph.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import copy +import mxnet as mx +import numpy as np +import pytest +from mxnet.contrib import quantization +from mxnet.gluon import nn +from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err +from mxnet.util import use_np +import math + +@use_np +@pytest.mark.parametrize('batch_size', [1, 32]) +@pytest.mark.parametrize('seq_length', [124, 384]) +@pytest.mark.parametrize('units', [256, 768]) +@pytest.mark.parametrize('num_heads', [4, 8]) +def test_self_attention(batch_size, seq_length, units, num_heads): + class MultiHeadAttention(nn.HybridBlock): + def __init__(self, units, num_heads, dtype='float32', **kwargs): + super(MultiHeadAttention, self).__init__(**kwargs) + self._units = units + self._num_heads = num_heads + self._fc = nn.Dense(in_units=self._units, units=3*self._units, flatten=False, dtype=dtype) + self._scale = math.sqrt(self._units // self._num_heads) + + def forward(self, x, mask): + x = mx.np.copy(x) + out = self._fc(x) + query, key, value = mx.np.split(out, 3, axis=-1) + query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1)) + key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1)) + value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1)) + scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2), mx.np.swapaxes(key, 1, 2), + transpose_b=True) + mask = mx.np.expand_dims(mask, axis=1).astype(np.bool) + attn_weights = mx.npx.masked_softmax(scores, mask=mask.astype(np.bool), + axis=-1, temperature=self._scale) + attn_weights = mx.npx.dropout(attn_weights, p=0.1) + context_vec = mx.npx.batch_dot(attn_weights, + mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3)) + context_vec = mx.npx.reshape(context_vec, (-2, -2, -1)) + + return context_vec + + net = MultiHeadAttention(units, num_heads) + in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32') + mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32') + + net.initialize() + fused_net = net + net.hybridize() + ref_out = net(in_data, mask) + + fused_net.optimize_for(in_data, mask, backend="MKLDNN") + out = fused_net(in_data, mask) + mx.nd.waitall() + + for i in range(len(out)): + assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy()) + + + calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1) + qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto', + exclude_layers=None, + exclude_layers_match=None, + calib_data=calib_data, + calib_mode='naive', + num_calib_batches=1, + ctx=mx.cpu()) + + qout = qnet(in_data, mask) + mx.nd.waitall() + + for i in range(len(ref_out)): + min_range = np.min(ref_out[i].asnumpy()) + max_range = np.max(ref_out[i].asnumpy()) + atol = 0.1 * max(abs(min_range), abs(max_range)) + assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)