Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Feature] Add oneDNN support for interleaved_matmul_selfatt_* operato…
Browse files Browse the repository at this point in the history
…rs (fp32/int8) (#20163)

* Add oneDNN code to interleved kernels

* check

* Fix selfattQK subgraph

* fix qk

* Fixes QK

* add test for oneDNN self_att qk

* basic valatt

* add valatt test

* refactor valatt

* fix review

* Change param struct name

* Fix sanity

* Fix sanity

Co-authored-by: grygielski <[email protected]>
  • Loading branch information
bgawrych and grygielski authored Apr 20, 2021
1 parent 4bd7ad5 commit 16d1da9
Show file tree
Hide file tree
Showing 6 changed files with 1,187 additions and 17 deletions.
23 changes: 10 additions & 13 deletions src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "mkldnn_fc_post_quantize_property.h"
#include "mkldnn_elemwisemul_post_quantize_property.h"
#include "mkldnn_post_quantize_align_scale_property.h"
#include "mkldnn_transformer_property.h"
#include "mkldnn_transformer_post_quantize_property.h"

namespace mxnet {
namespace op {
Expand All @@ -35,34 +37,29 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);

#endif // MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 1
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
#endif // MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 1

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);

MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
.set_attr("context", Context::CPU());

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
.set_attr("quantize", true);

#endif // MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 1

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
.set_attr("quantize", true);
#endif // MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 1

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
#endif // MXNET_USE_MKLDNN == 1

#if MXNET_USE_MKLDNN == 1
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
#endif // MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 1
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
58 changes: 58 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_

#include "../../mxnet_op.h"
#include "../../mshadow_op.h"


namespace mxnet {
namespace op {

struct MKLDNNSelfAttParam : public dmlc::Parameter<MKLDNNSelfAttParam> {
int heads;
bool quantized;
bool enable_float_output;
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) {
DMLC_DECLARE_FIELD(heads)
.describe("Set number of heads");
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("Whether it's a quantized InterleavedMatMul operator");
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
.describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe("The minimum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized InterleavedMatMul op to calculate primitive scale");
DMLC_DECLARE_FIELD(max_calib_range)
.set_default(dmlc::optional<float>())
.describe("The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized InterleavedMatMul op to calculate primitive scale");
}
};

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
Loading

0 comments on commit 16d1da9

Please sign in to comment.