This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN int8 fc #16457

Merged: 3 commits, Oct 13, 2019
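
Summary of the change: this patch ports the int8 FullyConnected operator to the MKL-DNN v1.0 API. Execution arguments now travel in an mkldnn_args_map_t (MXNet's alias for std::unordered_map<int, mkldnn::memory>) passed to RegisterPrimArgs, replacing the old SetNewMem/RegisterPrim pair; memory descriptors are queried with *_desc() instead of *_primitive_desc(); and the preprocessor guards move from MXNET_USE_MKLDNN == 1 (the v0.x integration) to MXNET_USE_MKLDNN == 100 (the v1.0 integration).

For context, here is a minimal self-contained sketch of the v1.0 execution pattern the hunks below adopt. The helper name RunInnerProduct and its free-standing setup are illustrative assumptions, not code from this PR:

// Sketch only: MKL-DNN v1.0 runs a primitive against an argument map keyed
// by MKLDNN_ARG_* tags, supplied at execution time rather than at creation.
#include <unordered_map>
#include "mkldnn.hpp"

using mkldnn_args_map_t = std::unordered_map<int, mkldnn::memory>;

void RunInnerProduct(const mkldnn::inner_product_forward::primitive_desc &pd,
                     const mkldnn::memory &src, const mkldnn::memory &weights,
                     const mkldnn::memory &dst, const mkldnn::memory *bias) {
  mkldnn::stream strm(pd.get_engine());
  mkldnn_args_map_t args = {
    {MKLDNN_ARG_SRC, src},
    {MKLDNN_ARG_WEIGHTS, weights},
    {MKLDNN_ARG_DST, dst},
  };
  if (bias != nullptr)
    args[MKLDNN_ARG_BIAS] = *bias;  // optional bias joins the map, as in the patch
  mkldnn::inner_product_forward(pd).execute(strm, args);
  strm.wait();
}
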
10 changes: 5 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -216,15 +216,15 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);

-std::unordered_map<int, mkldnn::memory> args = {
+mkldnn_args_map_t args = {
{MKLDNN_ARG_SRC, *data_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DST, *out_mem.second},
};
if (!full_param.default_param.no_bias) {
auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
fwd->fwd_pd.bias_desc());
-args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
+args[MKLDNN_ARG_BIAS] = *bias_mem;
}
MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
CommitOutput(out_data[fullc::kOut], out_mem);
@@ -298,7 +298,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
-std::unordered_map<int, mkldnn::memory> args = {
+mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
@@ -317,7 +317,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
ipBwdWeights_pd.diff_weights_desc(),
req[fullc::kWeight]);
-std::unordered_map<int, mkldnn::memory> args = {
+mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_SRC, *data_mem},
{MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
@@ -328,7 +328,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
ipBwdWeights_pd.diff_bias_desc(),
req[fullc::kBias]);
-args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
+args[MKLDNN_ARG_DIFF_BIAS] = *in_grad_bias.second;
}
MKLDNNStream::Get()->RegisterPrimArgs(
mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -24,7 +24,7 @@
* \author Ciyong Chen
*/

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
#include "../quantization_utils.h"

@@ -89,33 +89,40 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
auto &fwd = GetFCFwd(param, is_train, data, weight,
param.no_bias ? nullptr : &quantized_bias, out_md);

-auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc());
const mkldnn::memory *weight_mem = nullptr;

if (weight.IsDefaultData()) {
// We also need to modify the layout on the original weight array.
// Don't switch the sequence below, because the naive engine executes
// pushAsync synchronously.
-weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
-weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc());
+weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
} else {
weight_mem = weight.GetMKLDNNData();
-CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
+CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
}
-auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(),
+auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(),
req[fullc::kOut]);
-const mkldnn::memory *bias_mem = nullptr;
-if (!param.no_bias)
-bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
-
-fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+mkldnn_args_map_t args = {
+{MKLDNN_ARG_SRC, *data_mem},
+{MKLDNN_ARG_WEIGHTS, *weight_mem},
+{MKLDNN_ARG_DST, *out_mem.second},
+};
+
+const mkldnn::memory *bias_mem = nullptr;
+if (!param.no_bias) {
+bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc());
+args[MKLDNN_ARG_BIAS] = *bias_mem;
+}
+
+MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), args);
CommitOutput(out_data[fullc::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}

} // namespace op
} // namespace mxnet

-#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_USE_MKLDNN == 100
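
As a side note on the GetMKLDNNDataReorder(...desc()) calls in the hunk above: under v1.0, layouts are described and compared through mkldnn::memory::desc. A rough standalone sketch of such a reorder, shown as an illustration under assumptions rather than MXNet's actual helper:

// Sketch only: convert a memory into the layout a primitive expects,
// approximately what the GetMKLDNNDataReorder() calls above accomplish.
mkldnn::memory ReorderTo(const mkldnn::memory::desc &want, mkldnn::memory &src,
                         const mkldnn::engine &eng, mkldnn::stream &strm) {
  if (src.get_desc() == want)
    return src;  // already in the expected layout; no copy needed
  mkldnn::memory dst(want, eng);
  mkldnn::reorder(src, dst).execute(strm, src, dst);
  strm.wait();
  return dst;
}
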
src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
@@ -27,7 +27,7 @@
#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100

#include <mxnet/ndarray.h>
#include <vector>
10 changes: 5 additions & 5 deletions src/operator/quantization/quantized_fully_connected.cc
@@ -26,7 +26,7 @@
#include <vector>
#include "quantization_utils.h"
#include "../nn/fully_connected-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
#include "mkldnn/mkldnn_quantized_ops-inl.h"
#endif
@@ -94,7 +94,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_type->size(), num_inputs * 3);
CHECK_EQ(out_type->size(), 3U);

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
<< "QuantizedFullyConnected only supports int8/uint8 input, while "
<< in_type->at(0) << " is given.";
@@ -124,7 +124,7 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), num_inputs * 3);
CHECK_EQ(out_attrs->size(), 3U);

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
return MKLDNNStorageType(attrs, dev_mask, true,
dispatch_mode, in_attrs, out_attrs);
#else
@@ -292,7 +292,7 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
#endif
}

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
@@ -341,7 +341,7 @@ and max thresholds representing the thresholds for quantizing the float32 output
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", QuantizedFullyConnectedForwardExCPU)
#endif
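A hedged note on the quantized_bias consumed by the forward hunks above: MKL-DNN expects the bias of an int8 inner product in its int32 accumulation type. The rescaling itself happens outside this diff; a common scheme (an assumption for illustration, not necessarily the exact code this operator relies on) scales the float bias by the product of the data and weight scales:

// Sketch only: requantize a float32 bias to int32 for an int8 inner product,
// assuming per-tensor scales derived from calibration min/max ranges.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int32_t> QuantizeBiasS32(const std::vector<float> &bias_f32,
                                     float data_scale, float weight_scale) {
  const float bias_scale = data_scale * weight_scale;
  std::vector<int32_t> bias_s32(bias_f32.size());
  for (size_t i = 0; i < bias_f32.size(); ++i)
    bias_s32[i] = static_cast<int32_t>(std::lround(bias_f32[i] * bias_scale));
  return bias_s32;
}
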
2 changes: 1 addition & 1 deletion tests/python/quantization/test_quantization.py
@@ -407,7 +407,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
def test_quantized_fc():
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
if is_test_for_native_cpu():
-hasMKL = False;
+hasMKL = False
for key in os.environ.keys():
if operator.eq(key, "BUILD_TAG"):
if os.environ['BUILD_TAG'].find("MKL") != -1: