Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[FEATURE] Add interleaved batch_dot oneDNN fuses for new GluonNLP mod…
Browse files Browse the repository at this point in the history
…els (#20312)

* Add self attention fuse with oneDNN support

* Make ConvertWeightBias2MKLDNN inline function

* Add new fp32 ops names to amp list

* Switch to forward interface
  • Loading branch information
bgawrych committed Jun 29, 2021
1 parent 38e1416 commit 1d0bdfd
Show file tree
Hide file tree
Showing 10 changed files with 1,706 additions and 2 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,8 @@
FP32_FUNCS.extend([
'_sg_mkldnn_conv',
'_sg_mkldnn_fully_connected',
'_sg_mkldnn_selfatt_qk',
'_sg_mkldnn_selfatt_valatt',
])

# Functions that have to be cast to FP32 only for
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize
while True:
try:
network(*data_nd)
except TypeError as err:
except (ValueError, TypeError) as err:
if logger:
logger.warning(err)
logger.warning("Deduced input data descriptors failed to run forward pass."
Expand Down
31 changes: 30 additions & 1 deletion src/operator/subgraph/mkldnn/mkldnn_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_COMMON_H_
#if MXNET_USE_ONEDNN == 1
#include <vector>
#include "../../numpy/np_matrix_op-inl.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -86,7 +87,7 @@ static std::vector<float> GetWeightScales(const NDArray &weight, const NDArray *
return weight_scales;
}

static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
static inline void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
const mkldnn::memory::desc &weight_md,
const mkldnn::memory::desc *bias_md,
const int num_group, float data_scale,
Expand Down Expand Up @@ -131,6 +132,34 @@ static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bi
if (has_bias && data_scale) *bias = new_bias;
}


static inline bool CheckReshapeConditions(const nnvm::Node& node, const index_t out_index) {
const index_t split_output_index = node.inputs[0].index;
if (split_output_index != out_index)
return false;

const auto &reshape_param = nnvm::get<NumpyXReshapeParam>(node.attrs.parsed);
const auto newshape = reshape_param.newshape;

if (newshape.ndim() != 4 || !(newshape[0] == newshape[1] && newshape[0] == -2))
return false;

return true;
}

static inline bool CheckSwapAxisConditions(const nnvm::Node& node) {
auto params = node.attrs.dict;
int dim1 = 0, dim2 = 0;
if (params.count("dim1") && params.count("dim2")) {
dim1 = std::stoi(params.at("dim1"));
dim2 = std::stoi(params.at("dim2"));
} else {
return false;
}

return ((dim1 == 1 && dim2 == 2) || (dim1 == 2 && dim2 == 1));
}

} // namespace op
} // namespace mxnet

Expand Down
9 changes: 9 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include "mkldnn_fc_post_quantize_property.h"
#include "mkldnn_elemwisemul_post_quantize_property.h"
#include "mkldnn_post_quantize_align_scale_property.h"
#include "mkldnn_transformer_qk_property.h"
#include "mkldnn_transformer_valatt_property.h"
#include "mkldnn_transformer_post_quantize_property.h"

namespace mxnet {
namespace op {
Expand All @@ -37,6 +40,8 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNBNReLUProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerValAttProperty);

MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
.set_attr("context", Context::CPU());
Expand All @@ -46,11 +51,15 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
.set_attr("quantize", true);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerValAttProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty)
.set_attr("quantize", true);

} // namespace op
} // namespace mxnet
Expand Down
58 changes: 58 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_

#include "../../mshadow_op.h"
#include "../../mxnet_op.h"


namespace mxnet {
namespace op {

struct MKLDNNSelfAttParam : public dmlc::Parameter<MKLDNNSelfAttParam> {
int heads;
bool quantized;
bool enable_float_output;
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) {
DMLC_DECLARE_FIELD(heads)
.describe("Set number of heads.");
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("Whether it's a quantized self attention matmul operator.");
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
.describe("Whether to enable float32 output.");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe("The minimum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized self-attention op to calculate primitive scale.");
DMLC_DECLARE_FIELD(max_calib_range)
.set_default(dmlc::optional<float>())
.describe("The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized self-attention op to calculate primitive scale.");
}
};

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
Loading

0 comments on commit 1d0bdfd

Please sign in to comment.